In [None]:
# init
from pyquda.utils import gauge_utils
from pyquda.field import LatticeFermion
from pyquda.enum_quda import QudaParity
from pyquda import init, core, quda, pyqcu, mpi, pointer
import os
import sys
from time import perf_counter
import cupy as cp
test_dir = os.path.dirname(os.path.abspath("./"))
sys.path.insert(0, os.path.join(test_dir, ".."))
os.environ["QUDA_RESOURCE_PATH"] = ".cache"
latt_size = [16, 32, 32, 64]
grid_size = [1, 1, 1, 1]
Lx, Ly, Lz, Lt = latt_size
Nd, Ns, Nc = 4, 4, 3
Gx, Gy, Gz, Gt = grid_size
latt_size = [Lx//Gx, Ly//Gy, Lz//Gz, Lt//Gt]
_latt_size = [Lx//Gx//2, Ly//Gy, Lz//Gz, Lt//Gt]
Lx, Ly, Lz, Lt = latt_size
Vol = Lx * Ly * Lz * Lt
mpi.init(grid_size)
latt_shape = (Lt, Lz, Ly, Lx//2, Ns, Nc)
param = pyqcu.QcuParam()
param.lattice_size = latt_size
dslash = core.getDslash(latt_size, -3.5, 0, 0, anti_periodic_t=False)
kappa = 0.125
U = gauge_utils.gaussGauge(latt_size, 0)
dslash.loadGauge(U)

In [None]:
# give ans first
ans_e = cp.random.random(latt_shape) + 1j * \
    cp.random.random(latt_shape)  # ans_e
ans_o = cp.random.random(latt_shape) + 1j * \
    cp.random.random(latt_shape)  # ans_o
print("## ans_o = ", ans_o[0, 0, 0, 0, 0, 0])

In [None]:
# give x_o, b__o, r, r_tilde, p, v, s, t, latt_tmp0, latt_tmp1
x_o = cp.random.random(latt_shape) + 1j * cp.random.random(latt_shape)  # x_o
b_e = cp.zeros(latt_shape, cp.complex128)
b_o = cp.zeros(latt_shape, cp.complex128)
b__o = cp.zeros(latt_shape, cp.complex128)
r = cp.zeros(latt_shape, cp.complex128)
r_tilde = cp.zeros(latt_shape, cp.complex128)
p = cp.zeros(latt_shape, cp.complex128)
s = cp.zeros(latt_shape, cp.complex128)
v = cp.zeros(latt_shape, cp.complex128)
t = cp.zeros(latt_shape, cp.complex128)
latt_tmp0 = cp.zeros(latt_shape, cp.complex128)
latt_tmp1 = cp.zeros(latt_shape, cp.complex128)
zero = cp.zeros(latt_shape, cp.complex128)
# give r_norm2, MAX_ITER, TOL, rho_prev, rho, alpha, omega, beta, tmp0, tmp1, kappa
r_norm2 = 0
MAX_ITER = 1e2
TOL = 1e-6
rho_prev = 1
rho = 0
alpha = 1
omega = 1
beta = 0
tmp0 = 0
tmp1 = 0
kappa = 0.125

In [None]:
# define dslash
def dslash(src_o, dest_o):
    latt_tmp0 = zero
    latt_tmp1 = zero
    _latt_tmp0 = LatticeFermion(_latt_size, latt_tmp0)
    _latt_tmp1 = LatticeFermion(_latt_size, latt_tmp1)
    _src_o = LatticeFermion(_latt_size, src_o)
    pyqcu.dslashQcu(_latt_tmp0.even_ptr, _src_o.even_ptr,
                    U.data_ptr, param, 0)  # D_eo
    pyqcu.dslashQcu(_latt_tmp1.even_ptr, _latt_tmp0.even_ptr,
                    U.data_ptr, param, 1)  # D_oe
    return src_o-kappa**2*latt_tmp1

In [None]:
# give b'_o(b__0)
_latt_tmp0 = LatticeFermion(_latt_size, latt_tmp0)
_latt_tmp1 = LatticeFermion(_latt_size, latt_tmp1)
_ans_e = LatticeFermion(_latt_size, ans_e)
_ans_o = LatticeFermion(_latt_size, ans_o)
_b_e = LatticeFermion(_latt_size, b_e)
_b_o = LatticeFermion(_latt_size, b_o)
latt_tmp0 = zero
pyqcu.dslashQcu(_latt_tmp0.even_ptr, _ans_o.even_ptr,
                U.data_ptr, param, 0)  # give D_eo ans_o
b_e = ans_e-kappa*latt_tmp0
latt_tmp1 = zero
pyqcu.dslashQcu(_latt_tmp1.even_ptr, _ans_e.even_ptr,
                U.data_ptr, param, 1)  # give D_oe ans_e
b_o = ans_o-kappa*latt_tmp1
latt_tmp1 = zero
pyqcu.dslashQcu(_latt_tmp1.even_ptr, _b_e.even_ptr,
                U.data_ptr, param, 1)  # give D_oe b_e
b__o = b_o+kappa*latt_tmp1

In [None]:
print(b__o)

In [None]:
def dot(a, b):
    return cp.inner(a.flatten().conjugate(), b.flatten())

In [12]:
# bistabcg
dslash(x_o, r)
r = b__o - r
r_tilde = r
for i in range(1, int(MAX_ITER)):
    print("## rho:", rho)
    rho = dot(r_tilde, r)
    print("## beta:", beta)
    beta = (rho / rho_prev) * (alpha / omega)
    p = r + (p - v * omega) * beta
    # v = A * p
    dslash(p, v)
    print("## alpha:", alpha)
    alpha = rho / dot(r_tilde, v)
    s = r - v * alpha
    # t = A * s
    dslash(s, t)
    print("## omega:", omega)
    omega = dot(t, s) / dot(t, t)
    x_o = x_o + p * alpha + s * omega
    r = s - t * omega
    print("## r_norm2:", r_norm2)
    r_norm2 = dot(r, r)
    # break
    if (r_norm2 < TOL or i == MAX_ITER - 1):
        print("## turns:", i)
        break
    rho_prev = rho

print('## difference: ', cp.linalg.norm(x_o - ans_o) / cp.linalg.norm(ans_o))

## alpha: 1
## r_norm2: (nan+nanj)
## alpha: (nan+nanj)
## r_norm2: (nan+nanj)
## alpha: (nan+nanj)
## r_norm2: (nan+nanj)
## alpha: (nan+nanj)
## r_norm2: (nan+nanj)
## alpha: (nan+nanj)
## r_norm2: (nan+nanj)
## alpha: (nan+nanj)
## r_norm2: (nan+nanj)
## alpha: (nan+nanj)
## r_norm2: (nan+nanj)
## alpha: (nan+nanj)
## r_norm2: (nan+nanj)
## alpha: (nan+nanj)
## r_norm2: (nan+nanj)
## alpha: (nan+nanj)
## r_norm2: (nan+nanj)
## alpha: (nan+nanj)
## r_norm2: (nan+nanj)
## alpha: (nan+nanj)
## r_norm2: (nan+nanj)
## alpha: (nan+nanj)
## r_norm2: (nan+nanj)
## alpha: (nan+nanj)
## r_norm2: (nan+nanj)
## alpha: (nan+nanj)
## r_norm2: (nan+nanj)
## alpha: (nan+nanj)
## r_norm2: (nan+nanj)
## alpha: (nan+nanj)
## r_norm2: (nan+nanj)
## alpha: (nan+nanj)
## r_norm2: (nan+nanj)
## alpha: (nan+nanj)
## r_norm2: (nan+nanj)
## alpha: (nan+nanj)
## r_norm2: (nan+nanj)
## alpha: (nan+nanj)
## r_norm2: (nan+nanj)
## alpha: (nan+nanj)
## r_norm2: (nan+nanj)
## alpha: (nan+nanj)
## r_norm2: (nan+nan

KeyboardInterrupt: 

In [None]:
print(r.shape)

In [None]:
print(r_tilde.shape)