In [None]:
import tools as tl 
from jax import device_put
import os
from jax.config import config
config.update("jax_enable_x64", True)
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Read the data from the ROOT file.
# The twelve arrays returned by tl.readroot are passed through device_put one
# by one and bound, in order, to the names below.
(
    phif001, phif021, phif001MC, phif021MC,
    Kp, Km, Pip, Pim,
    KpMC, KmMC, PipMC, PimMC,
) = (device_put(host_array) for host_array in tl.readroot("data/Fastor.root"))

In [None]:
import numpy as onp
import jax.numpy as np
from jax import vmap
from functools import partial
import time
from jax import jit
from jax import grad
import os
from jax.config import config
config.update("jax_enable_x64", True)
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Squared invariant mass of the combined system Pb + Pc.
def invm_plus(Pb,Pc):
    """Return the Minkowski square of the summed four-vectors ``Pb + Pc``.

    The last axis of the inputs holds the four components.  The first three
    enter with a minus sign and the fourth with a plus sign, i.e. the metric
    is ``(-1, -1, -1, +1)``.
    NOTE(review): this presumably means the layout is (px, py, pz, E) -- confirm
    against the reader that produces the arrays.

    The contraction runs over the last axis, so both a single four-vector of
    shape ``(4,)`` and batches of shape ``(..., 4)`` are accepted.  No square
    root is taken: the result is the *squared* invariant mass.
    """
    total = Pb + Pc
    metric = np.array([-1,-1,-1,1])
    # Sum over the component axis (the axis the metric broadcasts against)
    # instead of a hard-coded axis 1, which only worked for 2-D input.
    return np.sum(total * (total * metric), axis=-1)

# Squared invariant mass of Pbc.
def invm(Pbc):
    """Return the Minkowski square of the four-vector(s) ``Pbc``.

    The last axis holds the four components; the first three are weighted by
    -1 and the fourth by +1 (metric ``(-1, -1, -1, +1)``).  The sum runs over
    the last axis, so input of shape ``(4,)`` or ``(..., 4)`` is accepted.
    No square root is taken: the result is the *squared* invariant mass.
    """
    metric = np.array([-1,-1,-1,1])
    # Contract over the component axis rather than a hard-coded axis 1 so that
    # a single four-vector (1-D input) works as well.
    return np.sum(Pbc * (Pbc * metric), axis=-1)

# 对bw_取绝对值
def _abs(bw_):
    conjbw = np.conj(bw_)
    return np.real(bw_*conjbw)

# Breit-Wigner formula
def BW(m_,w_,Sbc):
    """Breit-Wigner amplitude ``k / (m_**2 - Sbc - 1j*m_*w_)``.

    ``k`` is the square root of
    ``2*sqrt(2)*m_*|w_|*gamma / (pi*sqrt(m_**2 + gamma))`` with
    ``gamma = sqrt(m_**2 * (m_**2 + w_**2))``.

    m_  : mass parameter.
    w_  : width parameter (only its absolute value enters ``k``).
    Sbc : value subtracted from ``m_**2`` in the denominator
          (presumably a squared invariant mass -- confirm against callers).
    """
    mass_sq = m_*m_
    gamma = np.sqrt(mass_sq*(mass_sq + w_*w_))
    norm_sq = 2*np.sqrt(2)*m_*np.abs(w_)*gamma/np.pi/np.sqrt(mass_sq + gamma)
    denominator = mass_sq - Sbc - m_*w_*1j
    return np.sqrt(norm_sq)/denominator

# rho * e^{i*theta}
def phase(theta, rho):
    """Complex factor ``rho * (cos(theta) + 1j*sin(theta))``."""
    cos_part = np.cos(theta)
    sin_part = np.sin(theta)
    return rho * (cos_part + 1j*sin_part)

In [None]:
# Generate random numbers to use as fake data.
# NOTE(review): this rebinds Kp, Km, Pip, Pim, phif001 and phif021, so whatever
# was stored under those names before this cell ran is replaced.
size = 800000
# Dedicated, seeded generator: every run draws the same sample, so results and
# timings are comparable between runs, and NumPy's global RNG is left untouched.
_fake_rng = onp.random.RandomState(0)
Kp = _fake_rng.random_sample((size, 4))
Km = _fake_rng.random_sample((size, 4))
Pip = _fake_rng.random_sample((size, 4))
Pim = _fake_rng.random_sample((size, 4))
phif001 = _fake_rng.random_sample((size, 2))
phif021 = _fake_rng.random_sample((size, 2))

In [None]:
# Minkowski squares (via invm_plus) of the summed Kp+Km and Pip+Pim four-vectors.
phi, f0 = invm_plus(Kp, Km), invm_plus(Pip, Pim)
# Stack the two phif0 inputs along a new leading axis.
phif0 = np.asarray([phif001, phif021])

# Starting values of the fit parameters, chosen arbitrarily.
# Original note: this can be controlled through the length of the
# fit-parameter arrays.
phim = np.array([2.0, 1.0, 1.0, 1.0])
phiw = np.array([1.0, 2.0, 1.0, 1.0])
f0m = np.array([1.0, 1.0, 1.0, 3.0])
f0w = np.array([1.0, 1.0, 1.0, 1.0])
const = np.array([
    [2.0, 1.0, 1.0, 1.0],
    [1.0, 1.0, 1.0, 1.0],
])
rho = np.array([1.0, 1.0, 2.0, 1.0])
theta = np.array([1.0, 1.0, 1.0, 1.0])

In [None]:
# Combine the Breit-Wigner and tensor parts and differentiate the result.
# Variant using complex numbers.

def BW(m_,w_,Sbc):
    """Complex Breit-Wigner amplitude ``k / (m_**2 - Sbc - 1j*m_*w_)``.

    ``gamma = sqrt(m_**2 * (m_**2 + w_**2))`` and ``k`` is the square root of
    ``2*sqrt(2)*m_*|w_|*gamma / (pi*sqrt(m_**2 + gamma))``.
    """
    m2 = m_*m_
    gamma = np.sqrt(m2*(m2 + w_*w_))
    k = np.sqrt(2*np.sqrt(2)*m_*np.abs(w_)*gamma/np.pi/np.sqrt(m2 + gamma))
    propagator_denominator = m2 - Sbc - m_*w_*1j
    return k/propagator_denominator

def phase(theta, rho):
    """Return ``rho * (cos(theta) + 1j*sin(theta))``, i.e. rho * e^{i*theta}."""
    unit = np.cos(theta) + 1j*np.sin(theta)
    return rho * unit

def BW_f0(phim,phiw,f0m,f0w,phi,f0):
    """Element-wise product of two vmapped ``BW`` evaluations.

    ``BW`` is mapped over the leading axis of ``(phim, phiw)`` with ``Sbc``
    fixed to ``phi`` and over the leading axis of ``(f0m, f0w)`` with ``Sbc``
    fixed to ``f0``; the two results are multiplied.
    """
    bw_over_phi = vmap(partial(BW, Sbc=phi))
    bw_over_f0 = vmap(partial(BW, Sbc=f0))
    return bw_over_phi(phim, phiw) * bw_over_f0(f0m, f0w)

def phase_f0(theta_,rho_):
    """Apply ``phase`` entry by entry along the leading axis of both inputs."""
    return vmap(phase)(theta_, rho_)

def test_pw(phim,phiw,f0m,f0w,const,rho,theta,phif0,phi,f0):
    """Return ``-sum(log(I))`` for the intensity ``I`` built from complex amplitudes.

    Steps (index letters as used in the einsum strings):
      1. ``phif0[i,j,k]`` is contracted with ``const[i,l]`` over ``i``.
      2. each slice along the new leading axis is multiplied by the matching
         entry of ``phase_f0(theta, rho)``.
      3. the result is contracted with ``BW_f0(...)`` over that leading axis,
         leaving indices ``[j,k]``.
      4. ``_abs`` of the amplitude is summed over ``k`` to give ``I[j]``, and
         the logarithm of ``I`` is summed over ``j``.

    The former ``print(bw)`` debugging statement was removed: it dumped the
    whole amplitude array every time the function was evaluated.
    """
    ph = phase_f0(theta,rho)
    bw = BW_f0(phim,phiw,f0m,f0w,phi,f0)
    amplitude = np.einsum('ijk,il->ljk',phif0,const)
    amplitude = np.einsum('ijk,i->ijk',amplitude,ph)
    amplitude = np.einsum('ijk,ij->jk',amplitude,bw)
    intensity = np.real(np.sum(_abs(amplitude),axis=1))
    return -np.sum(np.log(intensity))

print(test_pw(phim,phiw,f0m,f0w,const,rho,theta,phif0,phi,f0))

m = (0,1,2,3,4,5,6)
grad_test_pw = jit(grad(test_pw,argnums=m))
%time onp.asarray(grad_test_pw(phim,phiw,f0m,f0w,const,rho,theta,phif0,phi,f0))
%time onp.asarray(grad_test_pw(phim,phiw,f0m,f0w,const,rho,theta,phif0,phi,f0))
%time onp.asarray(grad_test_pw(phim,phiw,f0m,f0w,const,rho,theta,phif0,phi,f0))
%time print(grad_test_pw(phim,phiw,f0m,f0w,const,rho,theta,phif0,phi,f0))

In [None]:
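# Combine the Breit-Wigner and tensor parts and differentiate the result.
# Variant without native complex numbers (uses the dplex helpers).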

import dplex

def BW(m_,w_,Sbc):
    """Breit-Wigner amplitude assembled with the ``dplex`` helpers.

    The denominator is built by ``dplex.dconstruct`` from ``m_*m_ - Sbc`` and
    from ``-m_*w_`` repeated ``Sbc.shape[0]`` times; ``k`` is then divided by
    it with ``dplex.ddivide``.
    NOTE(review): dconstruct presumably takes (real part, imaginary part) --
    confirm in dplex.
    """
    mass_sq = m_*m_
    gamma = np.sqrt(mass_sq*(mass_sq + w_*w_))
    k = np.sqrt(2*np.sqrt(2)*m_*np.abs(w_)*gamma/np.pi/np.sqrt(mass_sq + gamma))
    n_points = Sbc.shape[0]
    denominator = dplex.dconstruct(mass_sq - Sbc, -m_*w_*np.ones(n_points))
    return dplex.ddivide(k, denominator)

def phase(theta, rho):
    """Pass ``rho*cos(theta)`` and ``rho*sin(theta)`` to ``dplex.dconstruct``.

    NOTE(review): presumably the dplex representation of rho * e^{i*theta} --
    confirm the argument order expected by dconstruct.
    """
    cos_component = rho * np.cos(theta)
    sin_component = rho * np.sin(theta)
    return dplex.dconstruct(cos_component, sin_component)

def BW_f0(phim,phiw,f0m,f0w,phi,f0):
    """Combine the vmapped phi and f0 ``BW`` results with ``dplex.deinsum``.

    ``BW`` is mapped over ``(phim, phiw)`` with ``Sbc=phi`` and over
    ``(f0m, f0w)`` with ``Sbc=f0``.  Axis 1 of each result is moved to the
    front before both are handed to ``dplex.deinsum('ij,ij->ij', ...)``.
    """
    bw_phi = vmap(partial(BW,Sbc=phi))(phim,phiw)
    bw_f0 = vmap(partial(BW,Sbc=f0))(f0m,f0w)
    return dplex.deinsum('ij,ij->ij', np.moveaxis(bw_phi,1,0), np.moveaxis(bw_f0,1,0))
    
def phase_f0(theta_,rho_):
    """Map ``phase`` over the leading axis of ``theta_`` and ``rho_``."""
    return vmap(phase)(theta_, rho_)

def test_pw(phim,phiw,f0m,f0w,const,rho,theta,phif0,phi,f0):
    """Return ``-sum(log(I))`` using the ``dplex`` helpers instead of complex dtypes.

    Steps (index letters as used in the einsum strings):
      1. ``phif0[i,j,k]`` is contracted with ``const[i,l]`` over ``i`` and the
         result is converted with ``dplex.dtomine``.
      2. it is combined with the phase factors (``phase_f0`` output with axis 1
         moved to the front) through ``dplex.deinsum('ijk,i->ijk', ...)``.
      3. it is combined with ``BW_f0(...)`` through
         ``dplex.deinsum('ijk,ij->jk', ...)``.
      4. ``dplex.dabs`` of the result is summed over axis 1 and the logarithm
         is summed over the remaining axis.
    NOTE(review): the exact array layout used by dplex is not visible here --
    confirm in the dplex module.

    The former debugging prints of ``bw.shape`` and ``ph.shape`` were removed.
    """
    ph = np.moveaxis(phase_f0(theta,rho), 1, 0)
    bw = BW_f0(phim,phiw,f0m,f0w,phi,f0)
    amplitude = dplex.dtomine(np.einsum('ijk,il->ljk',phif0,const))
    amplitude = dplex.deinsum('ijk,i->ijk',amplitude,ph)
    amplitude = dplex.deinsum('ijk,ij->jk',amplitude,bw)
    intensity = np.real(np.sum(dplex.dabs(amplitude),axis=1))
    return -np.sum(np.log(intensity))

print(test_pw(phim,phiw,f0m,f0w,const,rho,theta,phif0,phi,f0))

m = (0,1,2,3,4,5,6)
grad_test_pw = jit(grad(test_pw,argnums=m))
grad_test_pw(phim,phiw,f0m,f0w,const,rho,theta,phif0,phi,f0)
%time onp.asarray(grad_test_pw(phim,phiw,f0m,f0w,const,rho,theta,phif0,phi,f0))
%time onp.asarray(grad_test_pw(phim,phiw,f0m,f0w,const,rho,theta,phif0,phi,f0))
%time print(grad_test_pw(phim,phiw,f0m,f0w,const,rho,theta,phif0,phi,f0))

In [None]:
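# Combine the Breit-Wigner and tensor parts and differentiate the result.
# Variant without complex numbers: the formula is split by hand into real and imaginary parts.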

def BW_ph_real(m_,w_,rho,theta,Sbc):
    """Real part of the Breit-Wigner amplitude multiplied by ``rho * e^{i*theta}``.

    With ``a = m_**2 - Sbc`` and ``b = m_*w_`` this evaluates
    ``k * rho * (a*cos(theta) - b*sin(theta)) / (a**2 + b**2)``, where ``k`` is
    the same normalisation factor as in the complex ``BW``.
    """
    gamma = np.sqrt(m_*m_*(m_*m_+w_*w_))
    k = np.sqrt(2*np.sqrt(2)*m_*np.abs(w_)*gamma/np.pi/np.sqrt(m_*m_+gamma))
    offset = m_**2 - Sbc
    width_term = m_*w_
    down = offset**2 + width_term**2
    rotated = offset * np.cos(theta) - width_term * np.sin(theta)
    return k * rho * rotated / down

def BW_ph_imag(m_,w_,rho,theta,Sbc):
    """Imaginary part of the Breit-Wigner amplitude multiplied by ``rho * e^{i*theta}``.

    With ``a = m_**2 - Sbc`` and ``b = m_*w_`` this evaluates
    ``k * rho * (a*sin(theta) + b*cos(theta)) / (a**2 + b**2)``, where ``k`` is
    the same normalisation factor as in the complex ``BW``.
    """
    gamma = np.sqrt(m_*m_*(m_*m_+w_*w_))
    k = np.sqrt(2*np.sqrt(2)*m_*np.abs(w_)*gamma/np.pi/np.sqrt(m_*m_+gamma))
    offset = m_**2 - Sbc
    width_term = m_*w_
    down = offset**2 + width_term**2
    rotated = offset * np.sin(theta) + width_term*np.cos(theta)
    return k * rho * rotated / down

def BW_real(m_,w_,Sbc):
    """Real part of the Breit-Wigner amplitude.

    With ``a = m_**2 - Sbc`` and ``b = m_*w_`` this evaluates
    ``k * a / (a**2 + b**2)``, where ``k`` is the same normalisation factor as
    in the complex ``BW``.
    """
    gamma = np.sqrt(m_*m_*(m_*m_+w_*w_))
    k = np.sqrt(2*np.sqrt(2)*m_*np.abs(w_)*gamma/np.pi/np.sqrt(m_*m_+gamma))
    offset = m_**2 - Sbc
    down = offset**2 + (m_*w_)**2
    return k * offset / down

def BW_imag(m_,w_,Sbc):
    """Imaginary part of the Breit-Wigner amplitude.

    With ``a = m_**2 - Sbc`` and ``b = m_*w_`` this evaluates
    ``k * b / (a**2 + b**2)``, where ``k`` is the same normalisation factor as
    in the complex ``BW``.
    """
    gamma = np.sqrt(m_*m_*(m_*m_+w_*w_))
    k = np.sqrt(2*np.sqrt(2)*m_*np.abs(w_)*gamma/np.pi/np.sqrt(m_*m_+gamma))
    offset = m_**2 - Sbc
    down = offset**2 + (m_*w_)**2
    return k * m_ * w_ / down

def BW_f0_real(phim,phiw,f0m,f0w,rho,theta,phi,f0):
    """Real part of the product (phase-weighted phi amplitude) * (f0 amplitude).

    Uses ``Re(x*y) = Re(x)*Re(y) - Im(x)*Im(y)``.  The phi factors come from
    ``BW_ph_real`` / ``BW_ph_imag`` mapped over ``(phim, phiw, rho, theta)``
    with ``Sbc=phi``; the f0 factors come from ``BW_real`` / ``BW_imag`` mapped
    over ``(f0m, f0w)`` with ``Sbc=f0``.
    """
    phi_re = vmap(partial(BW_ph_real,Sbc=phi))(phim,phiw,rho,theta)
    f0_re = vmap(partial(BW_real,Sbc=f0))(f0m,f0w)
    phi_im = vmap(partial(BW_ph_imag,Sbc=phi))(phim,phiw,rho,theta)
    f0_im = vmap(partial(BW_imag,Sbc=f0))(f0m,f0w)
    return phi_re * f0_re - phi_im * f0_im
    
def BW_f0_imag(phim,phiw,f0m,f0w,rho,theta,phi,f0):
    """Imaginary part of the product (phase-weighted phi amplitude) * (f0 amplitude).

    Uses ``Im(x*y) = Re(x)*Im(y) + Im(x)*Re(y)``.  The phi factors come from
    ``BW_ph_real`` / ``BW_ph_imag`` mapped over ``(phim, phiw, rho, theta)``
    with ``Sbc=phi``; the f0 factors come from ``BW_real`` / ``BW_imag`` mapped
    over ``(f0m, f0w)`` with ``Sbc=f0``.
    """
    phi_re = vmap(partial(BW_ph_real,Sbc=phi))(phim,phiw,rho,theta)
    f0_im = vmap(partial(BW_imag,Sbc=f0))(f0m,f0w)
    phi_im = vmap(partial(BW_ph_imag,Sbc=phi))(phim,phiw,rho,theta)
    f0_re = vmap(partial(BW_real,Sbc=f0))(f0m,f0w)
    return phi_re * f0_im + phi_im * f0_re

def real_pw(phim,phiw,f0m,f0w,const,rho,theta,phif0,phi,f0):
    """Return ``-sum(log(I))`` computed entirely with real arithmetic.

    ``phif0[i,j,k]`` is contracted with ``const[i,l]`` over ``i``.  The result
    is contracted separately with the real and the imaginary amplitude
    (``BW_f0_real`` / ``BW_f0_imag``) over the leading axis.  The two
    contractions are squared, added, summed over axis 1 to give ``I``, and the
    logarithm of ``I`` is summed over the remaining axis.
    """
    amplitude_args = (phim,phiw,f0m,f0w,rho,theta,phi,f0)
    bw_re = BW_f0_real(*amplitude_args)
    bw_im = BW_f0_imag(*amplitude_args)
    weighted = np.einsum('ijk,il->ljk',phif0,const)
    re_sq = np.einsum('ijk,ij->jk',weighted,bw_re)**2
    im_sq = np.einsum('ijk,ij->jk',weighted,bw_im)**2
    intensity = np.sum(re_sq+im_sq,axis=1)
    return -np.sum(np.log(intensity))

print(real_pw(phim,phiw,f0m,f0w,const,rho,theta,phif0,phi,f0))

m = (0,1,2,3,4,5,6)
grad_real_pw = jit(grad(real_pw,argnums=m))
# s = time.time()
# print(grad_real_pw(phim,phiw,f0m,f0w,const,rho,theta,phif0,phi,f0))
# e = time.time()
# print("time: ",e-s)
# s = time.time()
# print(grad_real_pw(phim,phiw,f0m,f0w,const,rho,theta,phif0,phi,f0))
# e = time.time()
# print("time: ",e-s)
# s = time.time()
# print(grad_real_pw(phim,phiw,f0m,f0w,const,rho,theta,phif0,phi,f0))
# e = time.time()
# print("time: ",e-s)
%time onp.asarray(grad_real_pw(phim,phiw,f0m,f0w,const,rho,theta,phif0,phi,f0))
%time onp.asarray(grad_real_pw(phim,phiw,f0m,f0w,const,rho,theta,phif0,phi,f0))
%time print(grad_real_pw(phim,phiw,f0m,f0w,const,rho,theta,phif0,phi,f0))