In [None]:
using ProgressMeter, PyCall, PyPlot, Cairo, Images, HDF5, MultivariateStats, Interpolations, Lasso, Distributions, ImageFiltering, Random
using _Data
using  NaNStatistics, Statistics
#using ScikitLearn
include("../../project_place_cell/functions/func_map.jl")
include("../../Decoding/Decoder_Functions.jl")
include("../../Decoding/Decoder_Pipeline.jl")
np = pyimport("numpy")

# Matplotlib defaults for small publication figures.
# Python attributes are read with property access (`.rcParams`); PyCall has
# deprecated the string-indexing form `obj["rcParams"]` for attribute lookup.
rc_params = PyDict(pyimport("matplotlib").rcParams);
for (key, value) in (
        "font.sans-serif" => ["Arial"],
        "font.size" => 7,
        "lines.linewidth" => 1,
        "lines.markersize" => 4,
        "xtick.major.size" => 2,
        "ytick.major.size" => 2,
        "axes.spines.top" => false,
        "axes.spines.right" => false,
        "xtick.major.pad" => 2,
        "ytick.major.pad" => 2,
        "axes.labelpad" => 2,
    )
    rc_params[key] = value
end

# Wrap an image matrix as a Cairo RGB24 image surface (no flipping).
# The Matrix{UInt32} method does the work; the other methods convert their
# input into a packed Matrix{UInt32} first.
cim(img::Matrix{UInt32}) = CairoImageSurface(img, Cairo.FORMAT_RGB24; flipxy = false)
# `reinterpret` returns a lazy ReinterpretArray, which is not a Matrix{UInt32};
# `collect` materializes it so the call dispatches to the method above.
cim(img::Matrix{ARGB32}) = cim(collect(reinterpret(UInt32, img)))
cim(img::Matrix{RGB24}) = cim(collect(reinterpret(UInt32, img)))
# Grayscale bytes: replicate the normalized value into R, G and B so the
# result is a Matrix{RGB24} (a Gray matrix has no cim method to dispatch to).
function cim(img::Matrix{UInt8})
    g = reinterpret(N0f8, img)
    return cim(RGB24.(g, g, g))
end
cim(img::Array{UInt8,3}) = cim(RGB24.(reinterpret(N0f8, img[:,:,1]), reinterpret(N0f8, img[:,:,2]), reinterpret(N0f8, img[:,:,3])))

"""
    downsample(img, n=4)

Block-average `img` over non-overlapping `n`×`n` tiles.

The result has size `(div(size(img, 1), n), div(size(img, 2), n))`; trailing
rows/columns that do not fill a whole tile are dropped. (The previous form
used `size - 1`, which also dropped the last complete tile whenever a
dimension was an exact multiple of `n`.)
"""
function downsample(img, n=4)
    h = div(size(img, 1), n)   # number of complete tiles along dim 1
    w = div(size(img, 2), n)   # number of complete tiles along dim 2
    # Sum the n*n strided sub-grids (one per offset inside a tile), then average.
    return sum(img[i:n:n*h, j:n:n*w] for i in 1:n, j in 1:n) / (n * n)
end

In [None]:
# Corner-cue datasets, one row per fish: [experiment id, server number, experimenter].
# Every dataset in this set was recorded by "jen".
datasets_corner_cue = [
    Any[experiment_id, server_no, "jen"] for (experiment_id, server_no) in (
        ("20220407_152537", 4),
        ("20220406_111526", 9),
        ("20220407_090156", 5),
        ("20220417_165530", 25),
        ("20220406_153842", 9),
        ("20220405_171444", 25),
        ("20220416_160516", 6),
    )
];

# Server holding the "chuyu" processed data for each dataset above (same order).
chuyu_server = [4, 9, 5, 2, 9, 4, 6]

# NOTE(review): 9 entries for 7 datasets, and not used in the code shown — confirm.
lengths = [90, 90, 90, 89, 90, 90, 90, 90, 90];

## Import data for a single fish

In [None]:
# Decoder / analysis parameters for the single-fish section.
activity_bins = 7
activity_shift = 4
n_pos = 60               # number of bins in long side
long_axis_in_mm = 47     # 47 for rectangular, 33 for others
use_amount = 1000


# Which dataset (row of datasets_corner_cue) to analyse.
i = 4

experiment_filename, server, experimenter = datasets_corner_cue[i]

# Data root: "/data" when running on host "roli-<server>", otherwise the
# "/nfs/data<server>" path (presumably an NFS mount of that host's data).
ds_Lorenz = Dataset(experiment_filename, "lorenz",
                    gethostname() == "roli-$(server)" ? "/data" : "/nfs/data$(server)")
ds_Chuyu = Dataset(experiment_filename, "chuyu",
                   gethostname() == "roli-$(chuyu_server[i])" ? "/data" : "/nfs/data$(chuyu_server[i])")

ds = Dataset(experiment_filename, experimenter,
             gethostname() == "roli-$(server)" ? "/data" : "/nfs/data$(server)")


In [None]:

# Load the raw behavior tracking datasets from behavior.h5 of this fish.
# Left-hand names map, in order, to the dataset keys in the tuple below.
C, heading, img_bg, y_fish, x_offset, x_fish, y_offset =
    h5open(ds, "behavior.h5"; raw = true) do file
        map(key -> read(file, key),
            ("C", "heading", "img_bg", "fish_yolk_y", "offset_x", "fish_yolk_x", "offset_y"))
    end;

# Last frame along the third dimension of the background stack.
img_bg_end = img_bg[:, :, end]

# orientation-corrected fish location (time binned)
# `valid_moving_indices` is loaded here too: the early/late window masks built
# later in this notebook start from it, and it lives in this same file (the
# all-fish loop reads it from here as well). The do-block closes the file
# even if a read throws.
(chamber_roi, x_fish_sweep_mean, y_fish_sweep_mean, speed_mm_s, loc_digital,
 x_digital, y_digital, x_bins, y_bins, valid_moving_indices) =
    h5open(joinpath(data_path(ds_Chuyu), "for_place_calculation_chamber_geometry_$(experiment_filename)_n60.h5"), "r") do position_file
        read(position_file, "chamber_roi"),
        read(position_file, "x_fish_sweep_mean"),
        read(position_file, "y_fish_sweep_mean"),
        read(position_file, "speed_mm_s"),
        read(position_file, "loc_digital"),
        read(position_file, "x_digital"),
        read(position_file, "y_digital"),
        read(position_file, "x_bins"),
        read(position_file, "y_bins"),
        read(position_file, "valid_moving_indices")
    end

# Sweeps whose speed exceeds 0.1 mm/s.
moving_valid = speed_mm_s .> 0.1;
# Number of digitized x bins spanned by the fish trajectory.
long_axis_in_bins = (maximum(x_digital)-minimum(x_digital)+1)


# Per-neuron spatial-information results from the `exclude_beginning_end`
# folder (file computed on A_dF, per its name). The do-block closes the file
# even if a read throws; readmmap arrays stay usable after the close, as before.
#_whole #spatial_info_4 done on merged cells
(valid_neurons, specificity_shuffle_z, specificity_population_z, specificity, bool_index) =
    h5open(joinpath(data_path(ds_Chuyu), "exclude_beginning_end/neuron_spatial_info_15_75_chamber_geometry_$(experiment_filename)_sigma1_n60_A_dF.h5"), "r") do file
        HDF5.readmmap(file["valid_neurons"]),
        HDF5.readmmap(file["specificity_shuffle_z"]),
        HDF5.readmmap(file["specificity_population_z"]),
        HDF5.readmmap(file["specificity"]),
        HDF5.readmmap(file["bool_index"])
    end

bool_index = BitArray(bool_index)



# Merged NMF traces, centroids and neuron labels. The do-block closes the file
# even if a read throws; readmmap arrays stay usable after the close, as before.
# NOTE(review): the variable is named A_dFF but holds the "A_dF" dataset (an
# "A_dFF" read used to be commented out here), while the all-fish loop reads
# "A_dFF" into the same name — confirm which trace is intended.
(A_dFF, z_all, centroid_x_all, centroid_y_all, neuron_label) =
    h5open(joinpath(data_path(ds_Chuyu), "NMF_merge.h5"), "r") do file
        HDF5.readmmap(file["A_dF"]),
        HDF5.readmmap(file["Z_all"]),
        HDF5.readmmap(file["X_all"]),
        HDF5.readmmap(file["Y_all"]),
        read(file, "neuron_label")
    end

# A_dFF is used as (sweep, neuron): columns are neurons, rows are sweeps.
n_neurons = size(A_dFF, 2)
n_sweeps = size(A_dFF, 1)



# Load the region masks and select the telencephalic neurons.
region_names, region_roi_bool = h5open(joinpath(data_path(ds_Chuyu), "region_roi_bool.h5"), "r") do file
    read(file, "region_names"),
    read(file, "region_roi_bool")
end


# Column of region_roi_bool holding the telencephalon mask; fail loudly if absent.
tel_index = findfirst(==("Telencephalon -"), region_names)
tel_index === nothing && error("region \"Telencephalon -\" not found in region_roi_bool.h5")

# A neuron is telencephalic if the maximum over its rows (rows whose
# neuron_label equals the neuron index) in the telencephalon column is set.
mask_tel = falses(n_neurons)
@showprogress for which_neuron in 1:n_neurons
    mask_tel[which_neuron] = maximum(region_roi_bool[neuron_label.==which_neuron, tel_index])
end

# Indices of telencephalic neurons.
mask_tel = findall(mask_tel)



# Background image extent in pixels.
w, l = size(img_bg, 1), size(img_bg, 2);

# Square bin grid over the image: the bin size is chosen so the longer side
# (plus 2 px, starting one pixel before the origin) fits in n_pos bins.
# Note: this overwrites the x_bins / y_bins read from the position file.
min_x, min_y = 0, 0
max_x, max_y = w, l

bin_interval = max((max_y - min_y + 2) / n_pos, (max_x - min_x + 2) / n_pos)

x_bins = collect(min_x-1:bin_interval:min_x+bin_interval*(n_pos)+1);
y_bins = collect(min_y-1:bin_interval:min_y+bin_interval*(n_pos)+1);

# Bin centres: midpoints between consecutive edges.
x_bins_mid = (x_bins[1:end-1] .+ x_bins[2:end]) ./ 2
y_bins_mid = (y_bins[1:end-1] .+ y_bins[2:end]) ./ 2;

# Digitize the sweep-averaged fish position onto this grid and flatten
# (x, y) bin indices into a single location index.
x_digital2 = np.digitize(x_fish_sweep_mean, x_bins)
y_digital2 = np.digitize(y_fish_sweep_mean, y_bins);
loc_digital2 = (y_digital2 .- 1) .* n_pos .+ x_digital2;


function bin_to_px(x, y, offset=true)
    # Map (possibly fractional) 1-based bin indices to pixel coordinates of the
    # bin centre, using the global `bin_interval`. With `offset`, the grid
    # origin (min_x - 1, min_y - 1) is added.
    px_x = (x .- 0.5) .* bin_interval
    px_y = (y .- 0.5) .* bin_interval
    offset || return px_x, px_y
    return px_x .+ (min_x - 1), px_y .+ (min_y - 1)
end

function px_to_bin(x, y)
    # Inverse of bin_to_px(x, y, true): pixel coordinates -> fractional bin
    # indices on the grid whose origin is (min_x - 1, min_y - 1).
    to_bin(p, origin) = 0.5 .+ ((p .- origin) ./ bin_interval)
    return to_bin(x, min_x - 1), to_bin(y, min_y - 1)
end

# Fish position expressed in (fractional) bin units.
x_in_bins, y_in_bins = px_to_bin(x_fish_sweep_mean, y_fish_sweep_mean);


# Create the output directory if it is missing. Previously a bare try/catch
# reported *any* mkdir failure (e.g. permissions) as "save path exists".
if isdir(data_path(ds_Lorenz))
    println("save path exists")
else
    mkdir(data_path(ds_Lorenz))
end

# Place-cell candidates: telencephalic neurons selected by
# Decoder.get_top_neurons from their specificity z-scores (use_amount of them).
temp = Decoder.get_top_neurons(use_amount, specificity_population_z[mask_tel], specificity_shuffle_z[mask_tel]);
place_candidates_unique = mask_tel[temp];

A_dFF_place_cells = A_dFF[:, place_candidates_unique];

## Single fish: early vs. late decoding

In [None]:
# List the `exclude_beginning_end` folder, built the same way as the
# spatial-info path used when loading (data_path + joinpath).
readdir(joinpath(data_path(ds_Chuyu), "exclude_beginning_end"))

In [None]:
# Early/late window masks over sweeps, starting from the valid moving sweeps:
# the early mask keeps indices < 30*120, the late mask keeps indices > 60*120.
# NOTE(review): 120 is presumably sweeps per minute, matching the "0_30" /
# "60_90" window labels used below — confirm.
# NOTE(review): requires `valid_moving_indices` to already be defined (it is
# stored in the for_place_calculation_..._n60.h5 position file).
start_mask = copy(valid_moving_indices)
start_mask[30*120:end] .= false
end_mask = copy(valid_moving_indices)
end_mask[1:60*120] .= false

# Fixed seed so the random balancing below is reproducible.
Random.seed!(3)

# flip some trues in the mask to make it even
# For each location bin, randomly drop the surplus valid sweeps from whichever
# window visited that bin more often, so both windows keep the same number of
# sweeps per bin. The sequence of `sample` calls (and hence the result) depends
# on the order of unique(loc_digital).
for l in unique(loc_digital)
    temp = intersect(findall(loc_digital .== l), findall(valid_moving_indices))
    start_occ = temp[temp .< 30*120]
    end_occ = temp[temp .> 60*120]

    d = length(start_occ) - length(end_occ)
    if d < 0
        end_mask[sample(end_occ, -d, replace=false)] .= false
    else
        start_mask[sample(start_occ, d, replace=false)] .= false
    end
end
# Convert to Bool and keep only the first length(loc_digital) sweeps.
start_mask = Bool.(start_mask)[1:length(loc_digital)]
end_mask = Bool.(end_mask)[1:length(loc_digital)]

In [None]:
# Inspection cell: display the digitized location per sweep read from the position file.
loc_digital

In [None]:
files = ["0_30", "60_90"] #["0_30", "0_60", "60_90"]
windows = [start_mask, end_mask]
errors_early_late = sizehint!(Vector{Any}(), 2)

# Decode position from the place cells in each window. The median error is
# printed in mm over the window's sweeps; the stored vector is the raw error
# in bins over *all* sweeps (the all-fish loop stores masked errors in mm).
for window in windows
    sweep_mask = window[1:length(loc_digital)]
    predicted, _, _ = Decoder.rolling_decoder(sweep_mask, A_dFF_place_cells, loc_digital, activity_bins, activity_shift, n_pos)

    dist = Decoder.get_distance(predicted[:, 1], x_in_bins, predicted[:, 2], y_in_bins)
    println(nanmedian(dist[sweep_mask]) * long_axis_in_mm/long_axis_in_bins, " mm")

    push!(errors_early_late, dist)
end

## All fish: early vs. late decoding

In [None]:

# Decoder parameters for the all-fish run (same values as the single-fish section).
activity_bins = 7
activity_shift = 4
n_pos = 60               # number of bins in long side
long_axis_in_mm = 47     # 47 for rectangular, 33 for others
use_amount = 1000

lengths = [90, 90, 90, 89, 90, 90, 90, 90, 90];

# Per-fish [early, late] error vectors: place cells and random-cell baseline.
errors_all, errors_all_rand = [], []

# For every corner-cue fish:
# 1. reload its data;
# 2. pick the top place-cell candidates plus an equally sized set of distinct
#    random non-candidate neurons;
# 3. build occupancy-balanced early/late window masks;
# 4. decode position in each window.
# Per-fish [early, late] errors (mm, restricted to the window's sweeps) are
# pushed to `errors_all` (place cells) and `errors_all_rand` (random cells).
for i in eachindex(datasets_corner_cue)

    experiment_filename = datasets_corner_cue[i][1]
    server = datasets_corner_cue[i][2]
    experimenter = datasets_corner_cue[i][3]

    ds_Lorenz = Dataset(experiment_filename, "lorenz", gethostname() == "roli-$(server)" ? "/data" : "/nfs/data$(server)")
    ds_Chuyu = Dataset(experiment_filename, "chuyu", gethostname() == "roli-$(chuyu_server[i])" ? "/data" : "/nfs/data$(chuyu_server[i])")

    ds = Dataset(experiment_filename, experimenter, gethostname() == "roli-$(server)" ? "/data" : "/nfs/data$(server)")


    C, heading, img_bg, y_fish, x_offset, x_fish, y_offset = h5open(ds, "behavior.h5"; raw = true) do file
        read(file, "C"),
        read(file, "heading"),
        read(file, "img_bg"),
        read(file, "fish_yolk_y"),
        read(file, "offset_x"),
        read(file, "fish_yolk_x"),
        read(file, "offset_y")
    end;

    img_bg_end = img_bg[:,:,end]

    # orientation-corrected fish location (time binned)
    position_file = h5open(joinpath(data_path(ds_Chuyu), "for_place_calculation_chamber_geometry_$(experiment_filename)_n60.h5"))
        chamber_roi = read(position_file,"chamber_roi")
        x_fish_sweep_mean = read(position_file,"x_fish_sweep_mean")
        y_fish_sweep_mean = read(position_file,"y_fish_sweep_mean")
        speed_mm_s = read(position_file, "speed_mm_s")
        loc_digital = read(position_file, "loc_digital")
        x_digital = read(position_file, "x_digital")
        y_digital = read(position_file, "y_digital")
        x_bins = read(position_file, "x_bins")
        y_bins = read(position_file, "y_bins")
        valid_moving_indices = read(position_file, "valid_moving_indices")
    close(position_file)

    long_axis_in_bins = (maximum(x_digital)-minimum(x_digital)+1)


    file = h5open(joinpath(data_path(ds_Chuyu), "exclude_beginning_end/neuron_spatial_info_15_75_chamber_geometry_$(experiment_filename)_sigma1_n60_A_dF.h5"), "r") #_whole #spatial_info_4 done on merged cells
        valid_neurons = HDF5.readmmap(file["valid_neurons"])
        specificity_shuffle_z = HDF5.readmmap(file["specificity_shuffle_z"])
        specificity_population_z = HDF5.readmmap(file["specificity_population_z"])
        specificity = HDF5.readmmap(file["specificity"])
        bool_index = HDF5.readmmap(file["bool_index"])
    close(file)

    bool_index = BitArray(bool_index)


    # NOTE(review): reads "A_dFF", although the spatial-info file used to pick the
    # place cells is the "_A_dF" variant and the single-fish section loads
    # "A_dF" into the same name — confirm which trace is intended.
    file = h5open(joinpath(data_path(ds_Chuyu), "NMF_merge.h5"), "r")
        A_dFF = HDF5.readmmap(file["A_dFF"])

        z_all = HDF5.readmmap(file["Z_all"])
        centroid_x_all = HDF5.readmmap(file["X_all"])
        centroid_y_all = HDF5.readmmap(file["Y_all"])

        neuron_label = read(file, "neuron_label");
    close(file)

    n_neurons = size(A_dFF, 2)
    n_sweeps = size(A_dFF, 1)


    # Region masks; used to select telencephalic neurons.
    file = h5open(joinpath(data_path(ds_Chuyu), "region_roi_bool.h5"))
        region_names = read(file, "region_names")
        region_roi_bool = read(file, "region_roi_bool")
    close(file)


    tel_index = findfirst(==("Telencephalon -"), region_names)
    tel_index === nothing && error("region \"Telencephalon -\" not found for $(experiment_filename)")

    mask_tel = falses(n_neurons)
    @showprogress for which_neuron in 1:n_neurons
        mask_tel[which_neuron] = maximum(region_roi_bool[neuron_label.==which_neuron, tel_index])
    end

    mask_tel = findall(mask_tel)


    w = size(img_bg, 1)
    l = size(img_bg, 2);

    # Define bins
    min_x = 0
    min_y = 0
    max_x = w
    max_y = l

    bin_interval = maximum([(max_y-min_y+2)/n_pos,(max_x-min_x+2)/n_pos])

    x_bins = collect(min_x-1:bin_interval:min_x+bin_interval*(n_pos)+1);
    y_bins = collect(min_y-1:bin_interval:min_y+bin_interval*(n_pos)+1);

    x_bins_mid = (x_bins[1:end-1]+x_bins[2:end])/2
    y_bins_mid = (y_bins[1:end-1]+y_bins[2:end])/2;

    # Digitize data, then we just count the number
    x_digital2 = np.digitize(x_fish_sweep_mean, x_bins)
    y_digital2 = np.digitize(y_fish_sweep_mean, y_bins);
    loc_digital2 = (y_digital2.-1).*n_pos.+x_digital2;


    function bin_to_px(x, y, offset=true)
        if offset
            return (x .- 0.5) .* bin_interval .+ (min_x-1), (y .- 0.5) .* bin_interval .+ (min_y-1)
        else
            return (x .- 0.5) .* bin_interval, (y .- 0.5) .* bin_interval
        end
    end

    function px_to_bin(x, y)
        return 0.5 .+ ((x .- (min_x-1)) ./ bin_interval), 0.5 .+ ((y .- (min_y-1)) ./ bin_interval)
    end

    x_in_bins, y_in_bins = px_to_bin(x_fish_sweep_mean, y_fish_sweep_mean);


    # Create the output directory if missing (no longer masks other mkdir errors).
    if isdir(data_path(ds_Lorenz))
        println("save path exists")
    else
        mkdir(data_path(ds_Lorenz))
    end

    temp = Decoder.get_top_neurons(use_amount, specificity_population_z[mask_tel], specificity_shuffle_z[mask_tel]);
    place_candidates_unique = mask_tel[temp];

    A_dFF_place_cells = A_dFF[:, place_candidates_unique]

    # Random baseline: `use_amount` *distinct* non-candidate neurons.
    # `rand(pool, n)` samples with replacement and could repeat neurons.
    # NOTE(review): the pool is all neurons, not only telencephalic ones
    # (`mask_tel`) like the place cells — confirm this is intended.
    Random.seed!(3)
    neurons = sample(setdiff(1:n_neurons, place_candidates_unique), use_amount, replace=false)
    A_dFF_random_cells = A_dFF[:, neurons]


    start_mask = copy(valid_moving_indices)
    start_mask[30*120:end] .= false
    end_mask = copy(valid_moving_indices)
    end_mask[1:60*120] .= false

    Random.seed!(3)

    # flip some trues in the mask to make it even
    for l in unique(loc_digital)
        temp = intersect(findall(loc_digital .== l), findall(valid_moving_indices))
        start_occ = temp[temp .< 30*120]
        end_occ = temp[temp .> 60*120]

        d = length(start_occ) - length(end_occ)
        if d < 0
            end_mask[sample(end_occ, -d, replace=false)] .= false
        else
            start_mask[sample(start_occ, d, replace=false)] .= false
        end
    end
    start_mask = Bool.(start_mask)
    end_mask = Bool.(end_mask)


    files = ["0_30", "60_90"] #["0_30", "0_60", "60_90"]
    windows = [start_mask, end_mask]

    # Decode position from the given cell traces in each window; returns one
    # error vector (mm, restricted to that window's sweeps) per window.
    decode_windows = function (cells)
        window_errors = sizehint!(Vector{Any}(), length(windows))
        for window in windows
            sweep_mask = window[1:length(loc_digital)]
            rolling_predicted, _, _ = Decoder.rolling_decoder(sweep_mask, cells, loc_digital, activity_bins, activity_shift, n_pos)

            errors = Decoder.get_distance(rolling_predicted[:, 1], x_in_bins, rolling_predicted[:, 2], y_in_bins)[sweep_mask] .* long_axis_in_mm ./ long_axis_in_bins
            println(nanmedian(errors), " mm")

            push!(window_errors, errors)
        end
        return window_errors
    end

    # place cells
    errors_early_late = decode_windows(A_dFF_place_cells)
    push!(errors_all, errors_early_late)

    # random cells
    errors_early_late_rand = decode_windows(A_dFF_random_cells)
    push!(errors_all_rand, errors_early_late_rand)

end

In [None]:
# NOTE(review): scratch cell with no effect on the analysis; safe to delete.
1+1

In [None]:
# Per-fish median place-cell decoding error (mm): late window, early window,
# and their difference (negative = error decreased over the session).
ende = map(c -> nanmedian(c[2]), errors_all)
start = map(c -> nanmedian(c[1]), errors_all)
ende .- start

In [None]:
# Same late/early medians for the random-cell baseline.
ende_bas = map(c -> nanmedian(c[2]), errors_all_rand)
start_bas = map(c -> nanmedian(c[1]), errors_all_rand)
ende_bas .- start_bas

In [None]:
# Late-minus-early change in median error per fish: place cells, then the
# random-cell baseline. x positions follow the number of fish decoded
# instead of a hard-coded 7.
scatter(1:length(ende), ende .- start)
scatter(1:length(ende_bas), ende_bas .- start_bas)

In [None]:
# Paired plot of per-fish median decoding error, early (x=1) vs late (x=2).
fig, ax = subplots(figsize=(0.7,1.5))

ax.scatter(fill(1, length(start)), start, color="black", s=2)
ax.scatter(fill(2, length(start)), ende, color="black", s=2)

# One grey line per fish joining its early and late values, drawn behind the dots.
for (early, late) in zip(start, ende)
    ax.plot([1,2], [early, late], linewidth=0.5, color="grey", zorder=-1)
end


ax.set_ylim(0, 15)
ax.set_xlim(0.8, 2.2)
ax.set_yticks([0, 15], labels=["0", "15"])
ax.set_xticks([1,2], labels=["Early", "Late"])
ax.set_ylabel("Decoder error (mm)", labelpad=-7)
tight_layout(pad=0.2)
fig.savefig("F2d_decoding_early_late.pdf", format="pdf",  transparent=true, dpi=300)



In [None]:
# Per fish: one-sided Mann-Whitney U test of late-window vs early-window
# errors (alternative "less": late errors smaller). NaNs are dropped first.
# `scipy.stats` is imported explicitly; `import scipy` alone does not
# guarantee the submodule is loaded.
let mannwhitneyu = pyimport("scipy.stats").mannwhitneyu
    for errs in errors_all
        late = filter(!isnan, errs[2])
        early = filter(!isnan, errs[1])

        println(mannwhitneyu(late, early, alternative="less"))
    end
end

In [None]:
fig, ax = subplots(3,3,figsize=(5,3.5))

# One panel per fish: overlaid histograms of early vs late decoding error (mm).
for use_fish in 1:length(errors_all)

    subplot(3,3, use_fish)

    # `np` is the numpy binding from the first cell (`numpy` is not defined).
    bins = np.linspace(0,50,25)

    # `plt.cm.get_cmap` was removed in Matplotlib 3.9; `plt.get_cmap` is the
    # supported equivalent.
    tab10 = plt.get_cmap("tab10", 10)
    color1 = tab10(0)
    color2 = tab10(1)

    h1 = hist(errors_all[use_fish][1], label="Early",bins=bins, histtype="stepfilled", fc=(color1[1], color1[2], color1[3], 0.5),ec=color1)
    h2 = hist(errors_all[use_fish][2], label="Late",bins=bins, histtype="stepfilled", fc=(color2[1], color2[2], color2[3], 0.5),ec=color2)

    # Tallest bar of either histogram (h[1] holds the counts), computed once
    # instead of via `append!`, which grew h2[1] in place on every call.
    ymax = max(maximum(h1[1]), maximum(h2[1]))
    plt.text(25, max(round(ymax/10)*10, ymax), "Fish $(use_fish)")
    xlim(0,50)
    xticks([0, 10, 20, 30, 40, 50])
    xlabel("Error (mm)")

    ylabel("Count", labelpad=-7)
    yticks([0, round(ymax/250)*250])
end
# Panels 8 and 9 are left empty (with 7 fish); the shared legend goes on panel 9.
subplot(3,3, 8)
axis("off")
subplot(3,3, 9)
axis("off")

handles, labels = ax[1].get_legend_handles_labels()
ax[9].legend(handles, labels, loc="lower right", frameon=false, handlelength=1.3)

fig.subplots_adjust(hspace=10)

tight_layout(pad=0)
savefig("OLD_early_late_decoding_separate.pdf", format="pdf",  transparent=true, dpi=300)


In [None]:
order = ["Early","Late"]
# `np` is the numpy binding from the first cell (`numpy` is not defined).
bins= np.linspace(0,50,25)

# `plt.cm.get_cmap` was removed in Matplotlib 3.9; `plt.get_cmap` is the
# supported equivalent.
tab10 = plt.get_cmap("tab10", 10)
color1 = tab10(0)
color2 = tab10(1)

# Decoding errors pooled over all fish: early vs late window histograms.
fig, ax = subplots(1,1,figsize=(1.1,1))
hist(vcat([e[1] for e in errors_all]...), label="Early",bins=bins, histtype="stepfilled", fc=(color1[1], color1[2], color1[3], 0.5),ec=color1)
hist(vcat([e[2] for e in errors_all]...), label="Late",bins=bins, histtype="stepfilled", fc=(color2[1], color2[2], color2[3], 0.5),ec=color2)
ax.legend(loc=(0.4, 0.5) ,frameon=false, handlelength=1.3)

xlim(0,50)
xticks([0, 25, 50])
yticks([0, 2000, 4000], labels=["0", "", "4000"])
xlabel("Error (mm)")
ylabel("Count", labelpad=-14)


tight_layout(pad=0)
savefig("OLD_early_late_decoding_together.pdf", format="pdf",  transparent=true, dpi=300)