In [None]:
import numpy as np
import matplotlib.pyplot as plt
from Network import Place_net, Grid_net, Coupled_Net
import brainpy as bp
import brainpy.math as bm
from matplotlib.animation import FuncAnimation
from scipy.signal import find_peaks

In [None]:
# Default parameters
# Grid spacings (spatial periods) of the grid-cell modules
lambda_1 = 3
lambda_2 = 4
lambda_3 = 5
Lambda = np.array([lambda_1, lambda_2, lambda_3])
# Track length: product of the spacings (3*4*5 = 60).
# NOTE(review): presumably chosen so the combined module phases label positions
# uniquely on [0, L) — confirm against the model description.
L = lambda_1 * lambda_2 * lambda_3
# Cell numbers
num_p = int(1280)*2  # number of place cells (points tiling the track, see x below)
rho_p = num_p/L  # place-cell density per unit track length
rho_g = rho_p
num_g = int(rho_g*2*np.pi) # grid cells per module: a module's phase space spans 2*pi, so this makes both networks have the same density rho
M = len(Lambda)  # number of grid modules
# Feature spaces: place cells tile the track [0, L); grid cells tile phase [0, 2*pi)
x = np.linspace(0, L, num_p, endpoint=False)
theta = np.linspace(0, 2 * np.pi, num_g, endpoint=False)
# Connection range (tuning width)
a_p = 0.3
a_g = a_p/Lambda*2*np.pi  # a_p converted to phase units per module (one period Lambda <-> 2*pi)
# Connection strength
J_p = 20
J_g = J_p
J_pg = J_p/50  # coupling strength between the networks; NOTE(review): direction not visible here
# Divisive normalization strength
k_p = 20.
k_g = Lambda/2/np.pi * k_p  # per-module value, rescaled by Lambda/(2*pi)
# Time constants
tau_p = 1.
tau_g = 2*np.pi * tau_p/Lambda  # per-module, rescaled by 2*pi/Lambda

# Closed-form amplitudes derived from the parameters above.
# NOTE(review): these look like steady-state bump amplitudes of the grid (Ag, one per
# module) and place (Ap) networks, with Rg the corresponding grid rate amplitude — confirm.
# np.sqrt returns nan (with a RuntimeWarning) if (rho*J)**2 < 8*sqrt(2*pi)*a*rho*k.
Ag = 1./(4*np.sqrt(np.pi)*a_g*rho_g*k_g)*(rho_g*J_g+np.sqrt((rho_g*J_g)**2-8*np.sqrt(2*np.pi)*a_g*rho_g*k_g))
Ap = 1./(4*np.sqrt(np.pi)*a_p*rho_p*k_p)*(rho_p*J_p+np.sqrt((rho_p*J_p)**2-8*np.sqrt(2*np.pi)*a_p*rho_p*k_p))
Rg = Ag**2/(1+k_g*rho_g*a_g*np.sqrt(2*np.pi)*Ag**2)


In [None]:
# Circular (angular) distance function
def circ_dis(phi_1, phi_2):
    """Signed angular difference ``phi_1 - phi_2`` shifted back by at most one period.

    A raw difference above pi has 2*pi subtracted; one below -pi has 2*pi added.
    The result therefore lies in [-pi, pi] whenever the raw difference is within
    [-3*pi, 3*pi] (for example, when both angles are in [0, 2*pi)). Larger
    differences are shifted only once. Operates element-wise via ``bm.where``.
    """
    two_pi = 2 * bm.pi
    delta = phi_1 - phi_2
    delta = bm.where(delta > bm.pi, delta - two_pi, delta)
    return bm.where(delta < -bm.pi, delta + two_pi, delta)

def Get_energy(sigma_g, sigma_phi, sigma_p, Ig, Ip, z, phi):
    """Energy = negative (unnormalised) log-posterior of position ``z`` and grid phases ``phi``.

    The log-posterior is the sum of three terms:

    * grid likelihood, for each module i:
      ``-sum((Ig[i, :] - fg_i)**2) / sigma_g[i]**2``, where ``fg_i`` is a Gaussian
      bump of width ``a_g[i]`` over ``theta`` centred at ``phi[i]``;
    * phase prior, for each module i:
      ``-exp(-d_i**2 / (8 * a_g[i]**2)) * d_i**2 / sigma_phi[i]**2``, where ``d_i`` is
      the circular distance between ``phi[i]`` and the phase implied by z,
      ``psi_i = 2*pi*mod(z / Lambda[i], 1)``;
    * place likelihood: ``-sum((Ip - fp)**2) / sigma_p**2``, where ``fp`` is a Gaussian
      bump of width ``a_p`` centred at ``z`` on the periodic track [0, L).

    Reads the module-level globals ``L``, ``num_p``, ``num_g``, ``Lambda``, ``M``,
    ``a_g``, ``a_p`` and ``circ_dis``.

    Parameters
    ----------
    sigma_g, sigma_phi : sequences indexed by module (length >= M)
        Noise scales of the grid likelihood and the phase prior.
    sigma_p : scalar
        Noise scale of the place likelihood.
    Ig : 2-D array
        Row ``i`` is compared against module i's bump (presumably shape (M, num_g)).
    Ip : array broadcastable against the num_p place-cell positions.
    z : scalar position on the track.
    phi : sequence of M grid phases (radians).

    Returns
    -------
    The energy ``-(grid term + prior term + place term)``.
    """
    def wrap_to_track(a, b, period):
        # Signed difference a - b, shifted by at most one period into [-period/2, period/2].
        delta = a - b
        delta = bm.where(delta > period / 2, delta - period, delta)
        return bm.where(delta < -period / 2, delta + period, delta)

    x_grid = np.linspace(0, L, num_p, endpoint=False)
    theta_grid = np.linspace(0, 2 * np.pi, num_g, endpoint=False)
    psi_z = np.mod(z / Lambda, 1) * 2 * np.pi  # phase of z within each module's period

    grid_term = 0
    prior_term = 0
    for i in range(M):
        bump_g = np.exp(-circ_dis(theta_grid, phi[i]) ** 2 / (4 * a_g[i] ** 2))
        grid_term -= np.sum((Ig[i, :] - bump_g) ** 2) / sigma_g[i] ** 2
        d_phase = circ_dis(phi[i], psi_z[i])
        prior_term -= 1 / (sigma_phi[i] ** 2) * np.exp(-d_phase ** 2 / 8 / a_g[i] ** 2) * d_phase ** 2

    bump_p = np.exp(-wrap_to_track(x_grid, z, L) ** 2 / (4 * a_p ** 2))
    place_term = -np.sum((Ip - bump_p) ** 2) / sigma_p ** 2

    return -(grid_term + prior_term + place_term)

def MAP_decoding(sigma_g, sigma_phi, sigma_p, Ig, Ip, candidate_num):
    Energy = Get_energy(sigma_g, sigma_phi, sigma_p, Ig, Ip, candidate_num)
    z_candidate = np.linspace(0, L , candidate_num)
    global_minimum = np.argmin(Energy)
    z_decode = z_candidate[global_minimum]
    
    return z_decode


def GOP_decoding(sigma_g, sigma_phi, sigma_p, Ig, Ip, candidate_num, z_init, L):
    # 计算能量
    Energy = Get_energy(sigma_g, sigma_phi, sigma_p, Ig, Ip, candidate_num)
    z_candidate = np.linspace(0, L, candidate_num)
    # 找到局部最小值
    peaks, _ = find_peaks(-Energy)
    local_minima = peaks
    
    # 找到离 z_init 最近的局部最小值
    closest_local_minimum_idx = np.argmin(np.abs(z_candidate[local_minima] - z_init))
    closest_local_minimum = local_minima[closest_local_minimum_idx]
    z_decode = z_candidate[closest_local_minimum]
    
    return z_decode


