In [None]:
using ProgressMeter, PyCall, PyPlot, Cairo, Images, HDF5, MultivariateStats, Interpolations, Lasso, Distributions, ImageFiltering, Random
using _Data
using  NaNStatistics, Statistics
#using ScikitLearn
include("../../project_place_cell/functions/func_map.jl")
include("../../Decoding/Decoder_Functions.jl")
include("../../Decoding/Decoder_Pipeline.jl")
np = pyimport("numpy")

# Global Matplotlib style for every figure produced by this notebook.
# `rc_params` wraps Matplotlib's own rcParams object, so each assignment below
# changes the live plotting defaults.
rc_params = PyDict(pyimport("matplotlib")["rcParams"]);

for (rc_key, rc_value) in (
    # fonts and line/marker sizes
    "font.sans-serif" => ["Arial"],
    "font.size" => 7,
    "lines.linewidth" => 1,
    "lines.markersize" => 4,
    # short ticks, no top/right frame
    "xtick.major.size" => 2,
    "ytick.major.size" => 2,
    "axes.spines.top" => false,
    "axes.spines.right" => false,
    # tight padding between ticks, tick labels and axis labels
    "xtick.major.pad" => 2,
    "ytick.major.pad" => 2,
    "axes.labelpad" => 2,
)
    rc_params[rc_key] = rc_value
end;

# `cim` wraps an image array in a Cairo image surface (RGB24 pixel format).
# The base method takes packed 32-bit pixels; the other methods convert to that layout.
cim(img::Matrix{UInt32}) = CairoImageSurface(img, Cairo.FORMAT_RGB24; flipxy = false)

# `reinterpret` returns a lazy `ReinterpretArray`, which does not match `Matrix{UInt32}`;
# `collect` materialises it so the call actually reaches the base method above.
cim(img::Matrix{ARGB32}) = cim(collect(reinterpret(UInt32, img)))
cim(img::Matrix{RGB24}) = cim(collect(reinterpret(UInt32, img)))

# Single-channel 8-bit image, viewed as grayscale.
# NOTE(review): no `cim` method for a `Gray` matrix is visible in this file — confirm one
# is defined elsewhere, otherwise this method throws a MethodError.
cim(img::Matrix{UInt8}) = cim(Gray.(reinterpret(N0f8, img)))

# 8-bit array whose third dimension holds the R, G and B planes (slices 1, 2, 3).
cim(img::Array{UInt8,3}) = cim(RGB24.(reinterpret(N0f8, img[:,:,1]), reinterpret(N0f8, img[:,:,2]), reinterpret(N0f8, img[:,:,3])));
"""
    downsample(img, n=4)

Average `img` over non-overlapping `n`×`n` tiles and return the reduced array.

Along each dimension only the first `n * div(size - 1, n)` rows/columns are used, so
trailing samples that do not fill a tile are dropped. Because of the `- 1`, a dimension
that is an exact multiple of `n` also loses its last tile (an 8×8 input with `n = 4`
yields a 1×1 result).
"""
function downsample(img, n=4)
    row_stop = n * div(size(img, 1) - 1, n)
    col_stop = n * div(size(img, 2) - 1, n)
    # One strided sub-image per offset inside a tile; summing them adds up every tile.
    tiles = [img[r:n:row_stop, c:n:col_stop] for r in 1:n, c in 1:n]
    return foldl(+, tiles) / (n * n)
end;

## import

In [None]:
# One entry per experiment: [experiment id, data-server number, experimenter].
# The server number selects the "/nfs/data<N>" mount (or "/data" when running on "roli-<N>").
datasets_corner_cue = Vector{Any}[
    Any["20220407_152537", 4, "jen"],
    Any["20220406_111526", 9, "jen"],
    Any["20220407_090156", 5, "jen"],
    Any["20220417_165530", 25, "jen"],
    Any["20220406_153842", 9, "jen"],
    Any["20220405_171444", 25, "jen"],
    Any["20220416_160516", 6, "jen"],
];

# Server number holding the "chuyu" copy of each dataset, in the same order as above.
chuyu_server = [4, 9, 5, 2, 9, 4, 6]

# NOTE(review): 9 entries for 7 datasets and not used in the visible code —
# presumably recording lengths in minutes; confirm or remove.
lengths = [90, 90, 90, 89, 90, 90, 90, 90, 90];

In [None]:
# Analysis parameters shared by the cells below.
activity_bins, activity_shift = 7, 4
n_pos = 60            # number of bins in long side
long_axis_in_mm = 47  # 47 for rectangular, 33 for others
use_amount = 1000

# Index into `datasets_corner_cue` of the dataset loaded by the single-dataset cell.
i = 1


In [None]:

# Load everything needed for dataset `i`: behaviour, binned fish position, place-cell
# statistics and neural activity. All HDF5 files are opened with `do` blocks so that the
# handle is closed even when one of the reads throws.
experiment_filename, server, experimenter = datasets_corner_cue[i]

# Data is read from "/data" on the machine that hosts it, otherwise from the NFS mount.
ds_Lorenz = Dataset(experiment_filename, "lorenz", gethostname() == "roli-$(server)" ? "/data" : "/nfs/data$(server)")
ds_Chuyu = Dataset(experiment_filename, "chuyu", gethostname() == "roli-$(chuyu_server[i])" ? "/data" : "/nfs/data$(chuyu_server[i])")

ds = Dataset(experiment_filename, experimenter, gethostname() == "roli-$(server)" ? "/data" : "/nfs/data$(server)")

C, heading, img_bg = h5open(ds, "behavior.h5"; raw = true) do file
    (read(file, "C"), read(file, "heading"), read(file, "img_bg"))
end;

# orientation-corrected fish location (time binned)
(chamber_roi, x_fish_sweep_mean, y_fish_sweep_mean, speed_mm_s, loc_digital,
 x_digital, y_digital, x_bins, y_bins, valid_moving_indices) =
    h5open(joinpath(data_path(ds_Chuyu), "for_place_calculation_chamber_geometry_$(experiment_filename)_n$(n_pos).h5")) do position_file
        (
            read(position_file, "chamber_roi"),
            read(position_file, "x_fish_sweep_mean"),
            read(position_file, "y_fish_sweep_mean"),
            read(position_file, "speed_mm_s"),
            read(position_file, "loc_digital"),
            read(position_file, "x_digital"),
            read(position_file, "y_digital"),
            read(position_file, "x_bins"),
            read(position_file, "y_bins"),
            read(position_file, "valid_moving_indices"),
        )
    end

moving_valid = speed_mm_s .> 0.1;

# Bounding box of the trajectory, floored to whole units.
min_x = floor(Int64, minimum(x_fish_sweep_mean));
max_x = floor(Int64, maximum(x_fish_sweep_mean));

min_y = floor(Int64, minimum(y_fish_sweep_mean));
max_y = floor(Int64, maximum(y_fish_sweep_mean));

# Edge length of one spatial bin: the larger extent (plus a margin of 2) split into n_pos bins.
interval = max((max_y - min_y + 2) / n_pos, (max_x - min_x + 2) / n_pos)

# Place-cell statistics (spatial_info_4 done on merged cells).
(place_cell_index, valid_neurons, specificity_shuffle_z, specificity_population_z,
 specificity, specificity_shuffle_p, bool_index) =
    h5open(joinpath(data_path(ds_Chuyu), "exclude_beginning_end/neuron_spatial_info_15_75_chamber_geometry_$(experiment_filename)_sigma1_n60_A_dF.h5"), "r") do file
        (
            read(file, "place_cell_index"),
            HDF5.readmmap(file["valid_neurons"]),
            HDF5.readmmap(file["specificity_shuffle_z"]),
            HDF5.readmmap(file["specificity_population_z"]),
            HDF5.readmmap(file["specificity"]),
            read(file, "specificity_shuffle_p"),
            HDF5.readmmap(file["bool_index"]),
        )
    end

bool_index = BitArray(bool_index)

# Neural activity and ROI metadata.
# Note: the variable is called `A_dFF` but it holds the "A_dF" dataset.
(A_dFF, z_all, centroid_x_all, centroid_y_all, neuron_label) =
    h5open(joinpath(data_path(ds_Chuyu), "NMF_merge.h5"), "r") do file
        (
            HDF5.readmmap(file["A_dF"]),
            HDF5.readmmap(file["Z_all"]),
            HDF5.readmmap(file["X_all"]),
            HDF5.readmmap(file["Y_all"]),
            read(file, "neuron_label"),
        )
    end

n_neurons = size(A_dFF, 2)
n_sweeps = size(A_dFF, 1)



# Convert (fractional) bin coordinates to pixel coordinates of the bin centres.
# Uses the globals `interval` (bin edge length) and, when `offset` is true,
# `min_x` / `min_y` to shift the result to the origin of the binned area.
function bin_to_px(x, y, offset=true)
    px_x = (x .- 0.5) .* interval
    px_y = (y .- 0.5) .* interval
    offset || return px_x, px_y
    return px_x .+ (min_x - 1), px_y .+ (min_y - 1)
end

# Convert pixel coordinates to (fractional) bin coordinates; inverse of
# `bin_to_px(x, y, true)`. Uses the globals `interval`, `min_x` and `min_y`.
function px_to_bin(x, y)
    origin_x, origin_y = min_x - 1, min_y - 1
    bin_x = 0.5 .+ ((x .- origin_x) ./ interval)
    bin_y = 0.5 .+ ((y .- origin_y) ./ interval)
    return bin_x, bin_y
end

# Fish position expressed in bin coordinates.
x_in_bins, y_in_bins = px_to_bin(x_fish_sweep_mean, y_fish_sweep_mean);

# Import the brain-region assignment of every ROI and keep only telencephalic place cells.
region_names, region_roi_bool = h5open(joinpath(data_path(ds_Chuyu), "region_roi_bool.h5")) do file
    (read(file, "region_names"), read(file, "region_roi_bool"))
end

tel_index = findfirst(==("Telencephalon -"), region_names)
tel_index === nothing && error("region \"Telencephalon -\" not found in region_roi_bool.h5")

mask_tel = falses(n_neurons)

for which_neuron in 1:n_neurons
    # A neuron counts as telencephalic if any ROI carrying its label lies in the region.
    # `any` returns false for a neuron without ROIs, where `maximum` would throw.
    mask_tel[which_neuron] = any(!iszero, region_roi_bool[neuron_label .== which_neuron, tel_index])
end

# From here on `mask_tel` holds neuron indices rather than a boolean mask.
mask_tel = findall(mask_tel)

place_cell_index = intersect(place_cell_index, mask_tel)

In [None]:
# Split the valid samples into an early epoch (indices below 30*120) and a late epoch
# (indices above 60*120).
# NOTE(review): 120 is presumably the number of samples per minute — confirm.
start_mask = copy(valid_moving_indices)
start_mask[30*120:end] .= false
end_mask = copy(valid_moving_indices)
end_mask[1:60*120] .= false

Random.seed!(3)

# Equalise occupancy: for every spatial bin, randomly drop samples from whichever epoch
# visited the bin more often, so both epochs keep the same number of samples per bin.
for bin_id in unique(loc_digital)
    visits = intersect(findall(==(bin_id), loc_digital), findall(valid_moving_indices))
    early_visits = filter(<(30 * 120), visits)
    late_visits = filter(>(60 * 120), visits)

    surplus = length(early_visits) - length(late_visits)
    if surplus >= 0
        start_mask[sample(early_visits, surplus, replace=false)] .= false
    else
        end_mask[sample(late_visits, -surplus, replace=false)] .= false
    end
end


# Per-neuron maps for each epoch; slices of neurons that are not place cells stay NaN.
mean_map_all_early = fill(NaN32, n_pos, n_pos, size(A_dFF, 2))
count_map_all_early = fill(NaN32, n_pos, n_pos, size(A_dFF, 2))

mean_map_all_late = fill(NaN32, n_pos, n_pos, size(A_dFF, 2))
count_map_all_late = fill(NaN32, n_pos, n_pos, size(A_dFF, 2))

@showprogress for neuron_idx in place_cell_index
    # Same computation for both epochs, early first.
    for (epoch_mask, mean_maps, count_maps) in (
        (start_mask, mean_map_all_early, count_map_all_early),
        (end_mask, mean_map_all_late, count_map_all_late),
    )
        neural_activity, which_loc = MAP.valid_activity_loc(A_dFF[:, neuron_idx], epoch_mask, loc_digital)
        mean_maps[:, :, neuron_idx], count_maps[:, :, neuron_idx], _ = MAP.calculate_map_direct(neural_activity, which_loc, n_pos; at_least_visit = 0, use_gaussian_filter=true, sigma=3, filter_mask = nothing)
    end
end

In [None]:
# Number of detected map components per place cell: row 1 = early epoch, row 2 = late epoch.
comp = fill(0, 2, length(place_cell_index))

# Label images of the detected components, stored by position within `place_cell_index`
# (not by neuron index); unused slices stay NaN.
comp_map_all_early = fill(NaN32, n_pos, n_pos, size(A_dFF, 2))
comp_map_all_late = fill(NaN32, n_pos, n_pos, size(A_dFF, 2))

for (cell_pos, neuron_idx) in enumerate(place_cell_index)
    for (row, mean_maps, comp_maps) in (
        (1, mean_map_all_early, comp_map_all_early),
        (2, mean_map_all_late, comp_map_all_late),
    )
        _, label_img, _ = map_components_peak(mean_maps[:, :, neuron_idx]; threshold = 8/10, components_size_threshold = 20)

        # One distinct label is subtracted — presumably the background label; confirm
        # that the label image always contains it.
        comp[row, cell_pos] = length(unique(label_img)) - 1
        comp_maps[:, :, cell_pos] = label_img
    end
end

In [None]:
# Mean number of components per place cell in the early epoch (row 1 of `comp`).
mean(comp[1,:])

In [None]:
# Mean number of components per place cell in the late epoch (row 2 of `comp`).
mean(comp[2,:])

In [None]:
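# Recorded results: mean number of components per place cell (early, late).
# Presumably copied from the two cells above — confirm before citing.
#
# normal (without occupancy subsampling):
#   early 1.63, late 1.516
#
# with occupancy subsampling:
#   early 1.639, late 1.605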

In [None]:
# Overlay the distributions of components per cell: early epoch (row 1) first,
# late epoch (row 2) drawn semi-transparent on top.
hist(comp[1,:], bins=1:7)
hist(comp[2,:], bins=1:7, alpha=0.5)

In [None]:
# Early/late maps restricted to the chamber area, one map per place cell.
use_gaussian_filter = true
sigma = 1
nr_shuffle = 1000
at_least_shift = 1

# Grid resolution and minimum visit count depend on whether the maps are smoothed.
# NOTE(review): `loc_digital` was loaded from a file written for the earlier `n_pos`;
# the unsmoothed setting (n_pos = 15) presumably needs matching position data — confirm.
if use_gaussian_filter
    n_pos = 60
    at_least_visit = 2
else
    n_pos = 15
    at_least_visit = 5
end;

# Digitise every pixel inside the chamber ROI onto the spatial grid to find the bins
# that belong to the chamber.
chamber_roi_xy = findall(chamber_roi.!=0)
chamber_roi_x = [xy[1] for xy in chamber_roi_xy]
chamber_roi_y = [xy[2] for xy in chamber_roi_xy];
# NumPy is bound to `np` in this notebook; the name `numpy` is not defined in this file.
chamber_roi_x_digital = np.digitize(chamber_roi_x, x_bins)
chamber_roi_y_digital = np.digitize(chamber_roi_y, y_bins);
chamber_roi_digital = (chamber_roi_y_digital.-1).*n_pos.+chamber_roi_x_digital;
mask_valid = MAP.calculate_mask_map_digital(chamber_roi_digital,n_pos).>0;
# Dilation followed by erosion, applied to the chamber mask.
mask_valid = erode(dilate(mask_valid))
nr_valid_pos = sum(mask_valid);
mask_invalid = .!mask_valid;

# One map per entry of `place_cell_index`, in the same order.
maps_early = []
maps_late = []


@showprogress for neuron_idx in place_cell_index
    # Same computation for both epochs, early first.
    for (epoch_mask, epoch_maps) in ((start_mask, maps_early), (end_mask, maps_late))
        neural_activity, which_loc = MAP.valid_activity_loc(A_dFF[:, neuron_idx], epoch_mask, loc_digital)

        cur_map, mask_map, activity_num_map = MAP.calculate_map_direct(neural_activity, which_loc, n_pos; at_least_visit = at_least_visit, use_gaussian_filter=use_gaussian_filter, sigma=sigma, filter_mask=mask_invalid)

        # Third return value divided element-wise by the second one.
        push!(epoch_maps, activity_num_map ./ mask_map)
    end
end

## early-late mode, size all

In [None]:

# Parameters for the all-datasets run (same values as for the single-dataset analysis).
activity_bins, activity_shift = 7, 4
n_pos = 60            # number of bins in long side
long_axis_in_mm = 47  # 47 for rectangular, 33 for others
use_amount = 1000

# NOTE(review): 9 entries for 7 datasets and not used in the visible code — confirm.
lengths = [90, 90, 90, 89, 90, 90, 90, 90, 90];

# Per-dataset results, appended in dataset order: each element is a 2 × n_place_cells
# matrix with row 1 = early epoch and row 2 = late epoch.
comps_all = []
size_all = []

# For every dataset: load position and activity data, restrict the place cells to the
# telencephalon, build occupancy-matched early/late maps, and append the per-cell field
# sizes to `size_all` and the per-cell component counts to `comps_all`.
for i in eachindex(datasets_corner_cue)

    experiment_filename, server, experimenter = datasets_corner_cue[i]

    # Data is read from "/data" on the machine that hosts it, otherwise from the NFS mount.
    ds_Lorenz = Dataset(experiment_filename, "lorenz", gethostname() == "roli-$(server)" ? "/data" : "/nfs/data$(server)")
    ds_Chuyu = Dataset(experiment_filename, "chuyu", gethostname() == "roli-$(chuyu_server[i])" ? "/data" : "/nfs/data$(chuyu_server[i])")

    ds = Dataset(experiment_filename, experimenter, gethostname() == "roli-$(server)" ? "/data" : "/nfs/data$(server)")

    # All HDF5 files are opened with `do` blocks so the handle is closed even if a read throws.
    C, heading, img_bg = h5open(ds, "behavior.h5"; raw = true) do file
        (read(file, "C"), read(file, "heading"), read(file, "img_bg"))
    end

    # orientation-corrected fish location (time binned)
    (chamber_roi, x_fish_sweep_mean, y_fish_sweep_mean, speed_mm_s, loc_digital,
     x_digital, y_digital, x_bins, y_bins, valid_moving_indices) =
        h5open(joinpath(data_path(ds_Chuyu), "for_place_calculation_chamber_geometry_$(experiment_filename)_n$(n_pos).h5")) do position_file
            (
                read(position_file, "chamber_roi"),
                read(position_file, "x_fish_sweep_mean"),
                read(position_file, "y_fish_sweep_mean"),
                read(position_file, "speed_mm_s"),
                read(position_file, "loc_digital"),
                read(position_file, "x_digital"),
                read(position_file, "y_digital"),
                read(position_file, "x_bins"),
                read(position_file, "y_bins"),
                read(position_file, "valid_moving_indices"),
            )
        end

    moving_valid = speed_mm_s .> 0.1

    # Bounding box of the trajectory, floored to whole units.
    min_x = floor(Int64, minimum(x_fish_sweep_mean))
    max_x = floor(Int64, maximum(x_fish_sweep_mean))

    min_y = floor(Int64, minimum(y_fish_sweep_mean))
    max_y = floor(Int64, maximum(y_fish_sweep_mean))

    # Edge length of one spatial bin: the larger extent (plus a margin of 2) split into n_pos bins.
    interval = max((max_y - min_y + 2) / n_pos, (max_x - min_x + 2) / n_pos)

    # Place-cell statistics (spatial_info_4 done on merged cells).
    (place_cell_index, valid_neurons, specificity_shuffle_z, specificity_population_z,
     specificity, specificity_shuffle_p, bool_index) =
        h5open(joinpath(data_path(ds_Chuyu), "exclude_beginning_end/neuron_spatial_info_15_75_chamber_geometry_$(experiment_filename)_sigma1_n60_A_dF.h5"), "r") do file
            (
                read(file, "place_cell_index"),
                HDF5.readmmap(file["valid_neurons"]),
                HDF5.readmmap(file["specificity_shuffle_z"]),
                HDF5.readmmap(file["specificity_population_z"]),
                HDF5.readmmap(file["specificity"]),
                read(file, "specificity_shuffle_p"),
                HDF5.readmmap(file["bool_index"]),
            )
        end

    bool_index = BitArray(bool_index)

    # Neural activity and ROI metadata.
    # Note: the variable is called `A_dFF` but it holds the "A_dF" dataset.
    (A_dFF, z_all, centroid_x_all, centroid_y_all, neuron_label) =
        h5open(joinpath(data_path(ds_Chuyu), "NMF_merge.h5"), "r") do file
            (
                HDF5.readmmap(file["A_dF"]),
                HDF5.readmmap(file["Z_all"]),
                HDF5.readmmap(file["X_all"]),
                HDF5.readmmap(file["Y_all"]),
                read(file, "neuron_label"),
            )
        end

    n_neurons = size(A_dFF, 2)
    n_sweeps = size(A_dFF, 1)

    # Bin <-> pixel conversions for the current dataset (they read `interval`, `min_x`, `min_y`).
    function bin_to_px(x, y, offset=true)
        if offset
            return (x .- 0.5) .* interval .+ (min_x-1), (y .- 0.5) .* interval .+ (min_y-1)
        else
            return (x .- 0.5) .* interval, (y .- 0.5) .* interval
        end
    end

    function px_to_bin(x, y)
        return 0.5 .+ ((x .- (min_x-1)) ./ interval), 0.5 .+ ((y .- (min_y-1)) ./ interval)
    end

    x_in_bins, y_in_bins = px_to_bin(x_fish_sweep_mean, y_fish_sweep_mean)

    # Import the brain-region assignment of every ROI and keep only telencephalic place cells.
    region_names, region_roi_bool = h5open(joinpath(data_path(ds_Chuyu), "region_roi_bool.h5")) do file
        (read(file, "region_names"), read(file, "region_roi_bool"))
    end

    tel_index = findfirst(==("Telencephalon -"), region_names)
    tel_index === nothing && error("region \"Telencephalon -\" not found in region_roi_bool.h5")

    mask_tel = falses(n_neurons)
    for which_neuron in 1:n_neurons
        # `any` returns false for a neuron without ROIs, where `maximum` would throw.
        mask_tel[which_neuron] = any(!iszero, region_roi_bool[neuron_label .== which_neuron, tel_index])
    end

    # From here on `mask_tel` holds neuron indices rather than a boolean mask.
    mask_tel = findall(mask_tel)

    place_cell_index = intersect(place_cell_index, mask_tel)

    # Random telencephalic neurons, as many as there are place cells (drawn with replacement).
    # Not used further inside this loop.
    Random.seed!(3)
    control_cell_index = rand(mask_tel, length(place_cell_index))

    # Early epoch = valid samples below index 30*120, late epoch = valid samples above 60*120.
    # NOTE(review): 120 is presumably the number of samples per minute — confirm.
    start_mask = copy(valid_moving_indices)
    start_mask[30*120:end] .= false
    end_mask = copy(valid_moving_indices)
    end_mask[1:60*120] .= false

    Random.seed!(3)

    # Equalise occupancy: for every spatial bin, randomly drop samples from whichever epoch
    # visited the bin more often.
    for l in unique(loc_digital)
        temp = intersect(findall(loc_digital .== l), findall(valid_moving_indices))
        start_occ = temp[temp .< 30*120]
        end_occ = temp[temp .> 60*120]

        d = length(start_occ) - length(end_occ)
        if d < 0
            end_mask[sample(end_occ, -d, replace=false)] .= false
        else
            start_mask[sample(start_occ, d, replace=false)] .= false
        end
    end

    # Per-neuron maps for each epoch; slices of neurons that are not place cells stay NaN.
    mean_map_all_early = fill(NaN32, n_pos, n_pos, size(A_dFF, 2))
    count_map_all_early = fill(NaN32, n_pos, n_pos, size(A_dFF, 2))

    mean_map_all_late = fill(NaN32, n_pos, n_pos, size(A_dFF, 2))
    count_map_all_late = fill(NaN32, n_pos, n_pos, size(A_dFF, 2))

    @showprogress for neuron_idx in place_cell_index
        for (epoch_mask, mean_maps, count_maps) in (
            (start_mask, mean_map_all_early, count_map_all_early),
            (end_mask, mean_map_all_late, count_map_all_late),
        )
            neural_activity, which_loc = MAP.valid_activity_loc(A_dFF[:, neuron_idx], epoch_mask, loc_digital)
            mean_maps[:, :, neuron_idx], count_maps[:, :, neuron_idx], _ = MAP.calculate_map_direct(neural_activity, which_loc, n_pos; at_least_visit = 0, use_gaussian_filter=true, sigma=1, filter_mask = nothing)
        end
    end

    # Row 1 = early epoch, row 2 = late epoch; one column per place cell.
    size_ = fill(0, 2, length(place_cell_index))
    comp = fill(0, 2, length(place_cell_index))
    threshold = 8/10
    bottom_activity = 0

    # `cell_pos` is the column in the result matrices. The original inner loops reused `i`
    # here, shadowing the dataset index.
    for (cell_pos, neuron_idx) in enumerate(place_cell_index)
        for (row, mean_maps) in ((1, mean_map_all_early), (2, mean_map_all_late))
            cell_map = mean_maps[:, :, neuron_idx]

            # Field size: number of bins above `threshold` of the way from
            # `bottom_activity` to the 95th percentile of the map (NaNs ignored).
            top_activity = nanpctile(cell_map, 95)
            field_threshold = (top_activity - bottom_activity) * threshold + bottom_activity
            size_[row, cell_pos] = sum(cell_map .> field_threshold)

            # Component count: distinct labels minus one (presumably the background label).
            components_peaks, img_label_valid, valid_components = map_components_peak(cell_map; threshold = 8/10, components_size_threshold = 20)
            comp[row, cell_pos] = length(unique(img_label_valid)) - 1
        end
    end

    push!(size_all, size_)
    push!(comps_all, comp)

end

In [None]:
# Mean field size (in bins) per dataset for each epoch: row 1 = early, row 2 = late.
size_early = [mean(@view sizes[1, :]) for sizes in size_all]
size_late = [mean(@view sizes[2, :]) for sizes in size_all]

# Change of the mean field size from early to late, per dataset.
size_late .- size_early

In [None]:
# Display the per-dataset mean field size (in bins) of the early epoch.
size_early

In [None]:
# size conversion from bin-units to mm²
# NOTE(review): (47/40)^2 ≈ 1.38 is the constant the field-size plot divides by.
# With long_axis_in_mm = 47 and n_pos = 60 a bin edge of 47/60 mm would be expected —
# confirm where the 40 comes from, and whether bin counts should be multiplied by the
# area of one bin rather than divided by this factor.
(47/40)^2

In [None]:
# Paired plot of the mean field size per dataset, early versus late epoch.
# Every value is divided by 1.38 (≈ (47/40)^2) as the bin-to-mm² conversion.
fig, ax = subplots(figsize=(0.7, 1.5))

# One black dot per dataset at x = 1 (early) and x = 2 (late).
for (x_pos, sizes) in ((1, size_early), (2, size_late))
    ax.scatter(fill(x_pos, length(sizes)), sizes ./ 1.38, color="black", s=2)
end

# Grey line connecting the early and late value of each dataset.
for (early, late) in zip(size_early, size_late)
    ax.plot([1, 2], [early, late] ./ 1.38, linewidth=0.5, color="grey", zorder=-1)
end

# Dataset 7 is drawn again in red to highlight it.
ax.plot([1, 2], [size_early[7], size_late[7]] ./ 1.38, linewidth=0.5, color="red", zorder=-1)

# Axis limits, ticks and labels.
ax.set_ylim(0, 150)
ax.set_xlim(0.8, 2.2)
ax.set_yticks([0, 150], labels=["0", "150"])
ax.set_xticks([1, 2], labels=["Early", "Late"])
ax.set_ylabel("Field size (mm²)", labelpad=-10)

tight_layout(pad=0.2)
fig.savefig("field_size_eq.pdf", format="pdf", transparent=true, dpi=300)


In [None]:
# delta_field_size = ende .- start   (i.e. late minus early; kept as a note, not executed)
# Display the per-dataset mean field size (in bins) of the late epoch.
size_late

In [None]:
# Per-dataset paired Wilcoxon signed-rank test on the field sizes of individual place
# cells: late (row 2) against early (row 1), one-sided with alternative="less".
# Iterates over whatever `size_all` contains instead of a hard-coded 1:7.
# NOTE(review): PyCall's documentation no longer recommends `@pyimport`; it is kept
# because the following cell relies on the `scipy` name it defines.
@pyimport scipy
for sizes in size_all
    println(scipy.stats.wilcoxon(sizes[2, :], sizes[1, :], alternative="less"))
end

In [None]:
# Paired Wilcoxon signed-rank test across datasets on the per-dataset mean field sizes.
# No `alternative` is passed here, unlike the per-dataset tests, which use "less".
scipy.stats.wilcoxon(size_early, size_late)