In [1]:
'''
## Overview
This notebook summarizes the analysis for the ring-attractor network
model for zebrafish


### Part I: Precise theoretical predictions vs simulation results
The notebook examines the correlation between the (a) activity
difference between the right and left rings and (b) angular head velocity.

The notebook also examines other predictions: activity symmetry. See 
"Manuscript_Supplementary/supplementary.pdf" for details.


### Part II: Approximate theoretical prediction vs simulation results
The notebook draws a scatter plot to see if the approximate theoretical 
predictions fit the simulation results.
The theoretical formula: right-left difference = slope * k omega, in the
derivation using first order approximation, slope depends on 
WD = Mean(W_s) - Mean(W_d), and when WD = 0, slope = 2. See 
"Manuscript_Supplementary/3. Shifter-ring network" for details. 

I haven't plot the results for the non-uniform network, since it requires more
adjustments, and I don't expect to get a extremely different result.


## Author
Siyuan Mei (mei@bio.lmu.de)


## Last update
2025-9-11: add docstring for the notebook
'''
%load_ext autoreload
%autoreload 2

import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm
from sklearn.model_selection import ParameterGrid
from scipy.integrate import solve_ivp
from matplotlib import colors
from scipy.signal import find_peaks
from scipy.stats import pearsonr
import pandas as pd
import seaborn as sns

import HD_utils.circular_stats as cstat
from HD_utils.network import *
from HD_utils.matrix import *
from HD_utils.adap_sim_move import *
from HD_utils.adap_sim_stable import *
from HD_utils.IO import *
from HD_utils.plot import *
from HD_utils.comput_property import *
from HD_utils.exam import *
import HD_utils.unequalHD as ueHD

import functions_21 as funcs

pd.options.display.max_columns = 100

In [2]:
cor_array_list = [[], [], [], []]  # A, B, C, D
rl_sym_dev_list = []
sym_sym_dev_list = []

# Uniform HD Network

In [3]:
# Simulation theta precision
theta_num = 50
dtheta = (2*np.pi)/theta_num
theta_range = np.arange(-np.pi+dtheta/2, np.pi, dtheta) # must use np.arange(-np.pi+dtheta/2, np.pi, dtheta) to make it symmetry
# Changeable parameters
ring_num = 3
actfun = max0x
weight_fun = vonmises_weight_2i1r_2
search_pars = {'JI': np.linspace(-50,0,6), 'JE': np.linspace(0,50,6), 'K0': np.linspace(-50,0,6), 'kappa': np.logspace(-0.2,1,6)}
file_pre_name = 'new_12'
# Default parameters
inputs = np.array([-1, -0.6, -0.3, -0.1, 0, 0.1, 0.3, 0.6, 1])
net_diff_equa = net_diff_equa_f_in
phi = -np.pi * 8/9
tau = 20 # ms
b0 = 1
bc = 1
bs = [bc, b0]
# Generated parameters
par_num = len(search_pars)
search_num = len(ParameterGrid(search_pars))
zeroid = np.where(inputs == 0)[0][0]
par_names = list(search_pars.keys())

network_evals, network_evaldes, network_acvs, network_pars, network_ts = load_pickle(
    ['evals', 'eval_des', 'acvs', 'pars', 'ts'], weight_fun, actfun, '90')
Vels, network_eval_moving, network_vvcor, network_acvs_moving, network_ts_moving, network_eval_moving_sum = load_pickle(
    ['moving_slope', 'moving_eval', 'moving_eval_des', 'moving_acvs', 'moving_ts', 'moving_eval_sum'], weight_fun, actfun, file_pre_name)

In [4]:
stat_3ring_uniform = pd.DataFrame(columns=['Class', 'Ring No.', 'Weight Func equality', 'Weight Func', 'Act Func', 'N', 'N: valid stationary', 'N: linearly integrates', 'Sign of corr', 'N: RL Corr (A)', 'Mean: RL Corr (A)', 'SD: RL Corr (A)', 'Min abs RL Corr (A)', 'N: RL Corr (B)', 'Mean: RL Corr (B)', 'SD: RL Corr (B)', 'Min abs RL Corr (B)', 'N: RL Corr (C)', 'Mean: RL Corr (C)', 'SD: RL Corr (C)', 'Min abs RL Corr (C)', 'N: RL Corr (D)', 'Mean: RL Corr (D)', 'SD: RL Corr (D)', 'Min abs RL Corr (D)',  'Mean: RL Sym $d_{rel}$', 'SD: RL Sym $d_{rel}$', 'Max: RL Sym $d_{rel}$', 'Mean: Sym ring Sym $d_{rel}$', 'SD: Sym ring Sym $d_{rel}$', 'Max: Sym ring Sym $d_{rel}$'])

structure = 'Zebrafish (uniform HD)'
wfun_eq = False
    
# Variables calculation
total_num = len(network_evals)
valid_index_s = np.where(network_evals == 'valid')[0]
valid_num = len(valid_index_s)
stable_mov_range, stable_mov_range_id, linear_mov_range, linear_mov_range_id = cal_linear_range(network_eval_moving, Vels, inputs, valid_index_s)
valid_index_part_linear = np.where( linear_mov_range[:,1] > 0.1 )[0]
valid_index_linear_move = np.where( linear_mov_range[:,1] == 1 )[0]
valid_index_stable_move = np.where( stable_mov_range[:,1] == 1 )[0]
linear_num = len(valid_index_linear_move)

# Mirror symmetry of the left and right rings
index_shape_mismatch, dev_shape_ratios, if_match = cal_lr_shape_match_loop(network_acvs_moving, valid_index_stable_move, zeroid)
dev_shape_ratios = dev_shape_ratios[valid_index_stable_move]

lr_match_pro, mean_lrmatch_dev, sd_lrmatch_dev, max_lrmatch_dev = 100-len(index_shape_mismatch)/len(valid_index_stable_move) * 100, \
    np.mean(dev_shape_ratios), np.std(dev_shape_ratios), np.max(dev_shape_ratios)
rl_sym_dev_list.append(dev_shape_ratios)

# Mirror symmetry of the central ring
index_shape_mismatch, dev_ratios, if_match = cal_central_shape_match_loop(network_acvs_moving, valid_index_stable_move, zeroid)
dev_ratios = dev_ratios[valid_index_stable_move]
mean_cmatch_dev, sd_cmatch_dev, max_cmatch_dev = np.mean(dev_ratios), np.std(dev_ratios), np.max(dev_ratios)
sym_sym_dev_list.append(dev_ratios)

for i in range(2):
    sign = 1 if i == 0 else -1
    corr_sign = 'positive' if i == 0 else 'negative'

    # Input - R-L correlation
    bump_amplitudes = cal_firate_a_acv_mean_a_peak(network_acvs_moving, inputs, valid_index_part_linear, bs, actfun, kind='zebrafish')
    input_diff_cors, input_diff_ps = cal_input_diff_cor(inputs, bump_amplitudes[4:], valid_index_part_linear, linear_mov_range_id) 
    print(np.sum(input_diff_cors<0))
    show_value = input_diff_cors[valid_index_part_linear]
    show_value[show_value*-sign > 0] = np.nan
    cor_list = []
    for j in range(4):
        cor_list.append(np.sum(show_value[:,j]*sign > 0))
        cor_list.append(np.nanmean(show_value[:,j]))
        cor_list.append(np.nanstd(show_value[:,j]))
        cor_list.append(np.nanmin(np.abs(show_value[:,j])))
        
        cor_array_list[j].append(show_value[:,j])

    append = [mean_lrmatch_dev, sd_lrmatch_dev, max_lrmatch_dev, mean_cmatch_dev, sd_cmatch_dev, max_cmatch_dev] if i == 0 else [np.nan] * 6
    stat_3ring_uniform.loc[i] = [structure, ring_num, wfun_eq, waf_df_names[weight_fun.__name__], waf_df_names[actfun.__name__], total_num, valid_num, linear_num, corr_sign] + cor_list + append
    # display(stat_3ring_uniform)

stat_3ring_uniform.iloc[:,1:]

76
76


  cor_list.append(np.nanmean(show_value[:,j]))
  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  cor_list.append(np.nanmin(np.abs(show_value[:,j])))


Unnamed: 0,Ring No.,Weight Func equality,Weight Func,Act Func,N,N: valid stationary,N: linearly integrates,Sign of corr,N: RL Corr (A),Mean: RL Corr (A),SD: RL Corr (A),Min abs RL Corr (A),N: RL Corr (B),Mean: RL Corr (B),SD: RL Corr (B),Min abs RL Corr (B),N: RL Corr (C),Mean: RL Corr (C),SD: RL Corr (C),Min abs RL Corr (C),N: RL Corr (D),Mean: RL Corr (D),SD: RL Corr (D),Min abs RL Corr (D),Mean: RL Sym $d_{rel}$,SD: RL Sym $d_{rel}$,Max: RL Sym $d_{rel}$,Mean: Sym ring Sym $d_{rel}$,SD: Sym ring Sym $d_{rel}$,Max: Sym ring Sym $d_{rel}$
0,3,False,vonMises,max0x,1296,390,246,positive,254,0.999842,0.000953,0.985983,178,0.997401,0.025033,0.66948,254,0.999842,0.000953,0.985983,254,0.999823,0.000418,0.996801,1e-06,2.2e-05,0.00112,5.058037e-07,1.5e-05,0.000949
1,3,False,vonMises,max0x,1296,390,246,negative,0,,,,76,-0.990435,0.057667,0.500382,0,,,,0,,,,,,,,,


# Non Uniform HD Network

In [7]:
# Simulation theta precision
theta_num = 50

dtheta_sym = (2*np.pi)/theta_num
theta_range_sym = np.arange(-np.pi+dtheta_sym/2, np.pi, dtheta_sym)

dtheta = np.repeat(dtheta_sym, theta_num)
theta_range = [theta_range_sym]

shifts = [-np.pi/2, np.pi/2]
for shift in shifts:
    dthetai, theta_rangei = ueHD.asy_theta_range(theta_range_sym, shift)
    dtheta = np.concatenate((dtheta, dthetai))
    theta_range.append(theta_rangei) # in the order: center, left, right

# Changeable parameters
ring_num = 3
actfun = max0x
weight_fun = vonmises_weight_2i1r_unequal_theta
search_pars = {'JI': np.linspace(-50,0,6), 'JE': np.linspace(0,50,6), 'K0': np.linspace(-50,0,6), 'kappa': np.logspace(-0.2,1,6)}
file_pre_name = 'new_15_1_copy'
# Default parameters
inputs = np.array([-1, -0.6, -0.3, -0.1, 0, 0.1, 0.3, 0.6, 1])
net_diff_equa = net_diff_equa_f_in
phi = -np.pi * 8/9
tau = 20 # ms
b0 = 1
bc = 1
# Generated parameters
par_num = len(search_pars)
search_num = len(ParameterGrid(search_pars))
zeroid = np.where(inputs == 0)[0][0]
par_names = list(search_pars.keys())
input_num = len(inputs)

network_evals, network_evaldes, network_acvs, network_pars, network_ts = load_pickle(
    ['evals', 'eval_des', 'acvs', 'pars', 'ts'], weight_fun, actfun, file_pre_name)

Vels, network_eval_moving, network_eval_moving_des, network_acvs_moving, network_ts_moving, network_eval_moving_sum = load_pickle(
    ['moving_slope', 'moving_eval', 'moving_eval_des', 'moving_acvs', 'moving_ts', 'moving_eval_sum'], weight_fun, actfun, file_pre_name)

In [8]:
stat_3ring_nonuniform = pd.DataFrame(columns=['Class', 'Ring No.', 'Weight Func equality', 'Weight Func', 'Act Func', 'N', 'N: valid stationary', 'N: linearly integrates', 'Sign of corr', 'N: RL Corr (A)', 'Mean: RL Corr (A)', 'SD: RL Corr (A)', 'Min abs RL Corr (A)', 'N: RL Corr (B)', 'Mean: RL Corr (B)', 'SD: RL Corr (B)', 'Min abs RL Corr (B)', 'N: RL Corr (C)', 'Mean: RL Corr (C)', 'SD: RL Corr (C)', 'Min abs RL Corr (C)', 'N: RL Corr (D)', 'Mean: RL Corr (D)', 'SD: RL Corr (D)', 'Min abs RL Corr (D)',  'Mean: RL Sym $d_{rel}$', 'SD: RL Sym $d_{rel}$', 'Max: RL Sym $d_{rel}$', 'Mean: Sym ring Sym $d_{rel}$', 'SD: Sym ring Sym $d_{rel}$', 'Max: Sym ring Sym $d_{rel}$'])

structure = 'Zebrafish (non-uniform HD)'
wfun_eq = False
bs = [bc, b0]
    
# Variables calculation
total_num = len(network_evals)
valid_index_s = np.where(network_evals == 'valid')[0]
valid_num = len(valid_index_s)
stable_mov_range, stable_mov_range_id, linear_mov_range, linear_mov_range_id = cal_linear_range(network_eval_moving, Vels, inputs, valid_index_s)
valid_index_part_linear = np.where( linear_mov_range[:,1] > 0.1 )[0]
valid_index_linear_move = np.where( linear_mov_range[:,1] == 1 )[0]
valid_index_stable_move = np.where( stable_mov_range[:,1] == 1 )[0]
linear_num = len(valid_index_linear_move)



# Mirror symmetry of the left and right rings
index_shape_mismatch, dev_shape_ratios, if_match = cal_lr_shape_match_loop(network_acvs_moving, valid_index_stable_move, zeroid)
dev_shape_ratios = dev_shape_ratios[valid_index_stable_move]

lr_match_pro, mean_lrmatch_dev, sd_lrmatch_dev, max_lrmatch_dev = 100-len(index_shape_mismatch)/len(valid_index_stable_move) * 100, \
    np.mean(dev_shape_ratios), np.std(dev_shape_ratios), np.max(dev_shape_ratios)
rl_sym_dev_list.append(dev_shape_ratios)

# Mirror symmetry of the central ring
index_shape_mismatch, dev_ratios, if_match = cal_central_shape_match_loop(network_acvs_moving, valid_index_stable_move, zeroid)
dev_ratios = dev_ratios[valid_index_stable_move]
mean_cmatch_dev, sd_cmatch_dev, max_cmatch_dev = np.mean(dev_ratios), np.std(dev_ratios), np.max(dev_ratios)
sym_sym_dev_list.append(dev_ratios)

for i in range(2):
    sign = 1 if i == 0 else -1
    corr_sign = 'positive' if i == 0 else 'negative'

    # Input - R-L correlation
    bump_amplitudes = cal_firate_a_acv_mean_a_peak(network_acvs_moving, inputs, valid_index_part_linear, bs, actfun, kind='zebrafish', weights=dtheta.reshape(3,theta_num))
    input_diff_cors, input_diff_ps = cal_input_diff_cor(inputs, bump_amplitudes[4:], valid_index_part_linear, linear_mov_range_id) 
    show_value = input_diff_cors[valid_index_part_linear]
    show_value[show_value*-sign > 0] = np.nan
    cor_list = []
    for j in range(4):
        cor_list.append(np.sum(show_value[:,j]*sign > 0))
        cor_list.append(np.nanmean(show_value[:,j]))
        cor_list.append(np.nanstd(show_value[:,j]))
        cor_list.append(np.nanmin(np.abs(show_value[:,j])))
        
        cor_array_list[j].append(show_value[:,j])

    append = [mean_lrmatch_dev, sd_lrmatch_dev, max_lrmatch_dev, mean_cmatch_dev, sd_cmatch_dev, max_cmatch_dev] if i == 0 else [np.nan]*6
    stat_3ring_nonuniform.loc[i] = [structure, ring_num, wfun_eq, waf_df_names[weight_fun.__name__], waf_df_names[actfun.__name__], total_num, valid_num, linear_num, corr_sign] + cor_list + append
    # display(stat_3ring_uniform)

stat_3ring_nonuniform.iloc[:,1:]

  cor_list.append(np.nanmean(show_value[:,j]))
  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  cor_list.append(np.nanmin(np.abs(show_value[:,j])))


Unnamed: 0,Ring No.,Weight Func equality,Weight Func,Act Func,N,N: valid stationary,N: linearly integrates,Sign of corr,N: RL Corr (A),Mean: RL Corr (A),SD: RL Corr (A),Min abs RL Corr (A),N: RL Corr (B),Mean: RL Corr (B),SD: RL Corr (B),Min abs RL Corr (B),N: RL Corr (C),Mean: RL Corr (C),SD: RL Corr (C),Min abs RL Corr (C),N: RL Corr (D),Mean: RL Corr (D),SD: RL Corr (D),Min abs RL Corr (D),Mean: RL Sym $d_{rel}$,SD: RL Sym $d_{rel}$,Max: RL Sym $d_{rel}$,Mean: Sym ring Sym $d_{rel}$,SD: Sym ring Sym $d_{rel}$,Max: Sym ring Sym $d_{rel}$
0,3,False,vonMises,max0x,1296,331,54,positive,119,0.996949,0.017168,0.828555,78,0.998585,0.003921,0.982326,119,0.996949,0.017168,0.828555,119,0.998712,0.003518,0.980263,6.097502e-12,5.555995e-11,1.543797e-09,4.479109e-12,6.039457e-11,2.589959e-09
1,3,False,vonMises,max0x,1296,331,54,negative,0,,,,41,-0.984298,0.076998,0.499274,0,,,,0,,,,,,,,,


In [9]:
stat_3ring = pd.concat([stat_3ring_uniform, stat_3ring_nonuniform], ignore_index=True)
stat_3ring.loc[len(stat_3ring)] = ['Zebrafish (combined)', ring_num, None, waf_df_names[weight_fun.__name__], waf_df_names[actfun.__name__],
    stat_3ring['N'].sum()//2, stat_3ring['N: valid stationary'].sum()//2, stat_3ring['N: linearly integrates'].sum()//2, 
    None,
    None, np.nanmean(np.concatenate(cor_array_list[0])), np.nanstd(np.concatenate(cor_array_list[0])), np.nanmin(np.concatenate(cor_array_list[0])), 
    None, np.nanmean(np.concatenate(cor_array_list[1])), np.nanstd(np.concatenate(cor_array_list[1])), np.nanmin(np.concatenate(cor_array_list[1])), 
    None, np.nanmean(np.concatenate(cor_array_list[2])), np.nanstd(np.concatenate(cor_array_list[2])), np.nanmin(np.concatenate(cor_array_list[2])), 
    None, np.nanmean(np.concatenate(cor_array_list[3])), np.nanstd(np.concatenate(cor_array_list[3])), np.nanmin(np.concatenate(cor_array_list[3])), 
    np.nanmean(np.concatenate(rl_sym_dev_list)), np.nanstd(np.concatenate(rl_sym_dev_list)), np.nanmax(np.concatenate(rl_sym_dev_list)), 
    np.nanmean(np.concatenate(sym_sym_dev_list)), np.nanstd(np.concatenate(sym_sym_dev_list)), np.nanmax(np.concatenate(sym_sym_dev_list))]
# stat_3ring.to_csv(TABLE_PATH / 'Supplementary table 6. Simulation results of zebrafish HD model.csv', index=False)
stat_3ring.to_pickle(STAT_PATH / 'stat_zebrafish.npy')