In [1]:
# init
import os
import sys
from time import perf_counter

# Extend sys.path and set QUDA's environment *before* importing pyquda: a
# sys.path entry inserted after a package has already been imported has no
# effect on that import.
# NOTE(review): a notebook has no __file__, and dirname(abspath("./")) is the
# *parent* of the working directory, so the inserted entry is the working
# directory's grandparent — confirm that is the intended repository root.
test_dir = os.path.dirname(os.path.abspath("./"))
sys.path.insert(0, os.path.join(test_dir, ".."))
# Directory QUDA uses for its tuning cache (the log shows .cache/tunecache.tsv).
os.environ["QUDA_RESOURCE_PATH"] = ".cache"

from pyquda.utils import gauge_utils
from pyquda.field import LatticeFermion
from pyquda import core, pyqcu, mpi
import cupy as cp

# Global lattice extents (x, y, z, t) and the MPI process grid.
latt_size = [32, 32, 32, 64]
grid_size = [1, 1, 1, 1]
Lx, Ly, Lz, Lt = latt_size
Nd, Ns, Nc = 4, 4, 3
Gx, Gy, Gz, Gt = grid_size
# From here on latt_size / Lx..Lt describe the per-process (local) lattice:
# the global extents divided by the process grid.
latt_size = [Lx//Gx, Ly//Gy, Lz//Gz, Lt//Gt]
Lx, Ly, Lz, Lt = latt_size
Vol = Lx * Ly * Lz * Lt
mpi.init(grid_size)
# Two full-lattice scratch fermions; the dslash helpers copy their inputs
# into latt_tmp1 and read results back from latt_tmp0.
latt_tmp0 = LatticeFermion(latt_size, cp.zeros(
    (Lt, Lz, Ly, Lx, Ns, Nc), cp.complex128))
latt_tmp1 = LatticeFermion(latt_size, cp.zeros(
    (Lt, Lz, Ly, Lx, Ns, Nc), cp.complex128))
# Shape of a single-parity field: the x extent is halved.
latt_shape = (Lt, Lz, Ly, Lx//2, Ns, Nc)
param = pyqcu.QcuParam()
param.lattice_size = latt_size
# In this notebook the QUDA dslash object is only used to load the gauge
# field below.
# NOTE(review): the positional arguments (-3.5, 0, 0) are kept as-is —
# confirm their meaning against pyquda's core.getDslash signature.
_dslash = core.getDslash(latt_size, -3.5, 0, 0, anti_periodic_t=False)
kappa = 0.125
# Random gauge field (seed 0); its device pointer U.data_ptr is what the
# pyqcu dslash calls consume.
U = gauge_utils.gaussGauge(latt_size, 0)
_dslash.loadGauge(U)

Disabling GPU-Direct RDMA access
Enabling peer-to-peer copy engine and direct load/store access
QUDA 1.1.0 (git 1.1.0--sm_80)
CUDA Driver version = 12040
CUDA Runtime version = 12030
Found device 0: NVIDIA GeForce RTX 4060 Laptop GPU
 -- This might result in a lower performance. Please consider adjusting QUDA_GPU_ARCH when running cmake.

Using device 0: NVIDIA GeForce RTX 4060 Laptop GPU
Loaded 20 sets of cached parameters from .cache/tunecache.tsv
Loaded 20 sets of cached parameters from .cache/tunecache.tsv
cublasCreated successfully
Creating Gaussian distrbuted Lie group field with sigma = 1.000000e-01


In [2]:
# give ans first: the exact even/odd solution the solver is expected to recover.
# Seed CuPy's global RNG so ans_e / ans_o (and any later cp.random draws, such
# as the initial guess) are reproducible across kernel restarts.
cp.random.seed(0)
ans_e = cp.random.random(latt_shape) + 1j * \
    cp.random.random(latt_shape)  # ans_e
ans_o = cp.random.random(latt_shape) + 1j * \
    cp.random.random(latt_shape)  # ans_o

In [3]:
# give x_o, b__o, r, r_tilde, p, v, s, t
# x_o is a random initial guess for the odd-site solve; the other fields start
# at zero. p and v in particular must be zero: the first BiCGSTAB iteration
# reads them (p - v * omega) before assigning them.
x_o = cp.random.random(latt_shape) + 1j * cp.random.random(latt_shape)  # x_o
b_e = cp.zeros(latt_shape, cp.complex128)
b_o = cp.zeros(latt_shape, cp.complex128)
b__o = cp.zeros(latt_shape, cp.complex128)
r = cp.zeros(latt_shape, cp.complex128)
r_tilde = cp.zeros(latt_shape, cp.complex128)
p = cp.zeros(latt_shape, cp.complex128)
s = cp.zeros(latt_shape, cp.complex128)
v = cp.zeros(latt_shape, cp.complex128)
t = cp.zeros(latt_shape, cp.complex128)
# give r_norm2, MAX_ITER, TOL, rho_prev, rho, alpha, omega, beta, kappa
r_norm2 = 0  # holds the residual norm (not its square) once the solver runs
MAX_ITER = 1000  # an iteration count, so an int rather than the float 1e3
TOL = 1e-16  # convergence threshold used by the solver loop
# rho_prev = alpha = omega = 1 makes the first beta equal to rho itself.
rho_prev = 1
rho = 0
alpha = 1
omega = 1
beta = 0
kappa = 0.125  # same value as set in the init cell

In [4]:
# define dslash
def dslash_eo(src_o):
    """Apply D_eo: odd-parity source -> even-parity result.

    `src_o` is copied into the odd half of the scratch fermion `latt_tmp1`.
    `pyqcu.dslashQcu` (parity flag 0) then writes into the even half of
    `latt_tmp0`, and a copy of that even half is returned.

    The copy matters: `latt_tmp0` is a shared scratch buffer reused by every
    dslash call. Returning a view would let the next call silently overwrite
    a result the caller is still holding.
    """
    latt_tmp1.data[1, :] = src_o
    pyqcu.dslashQcu(latt_tmp0.even_ptr, latt_tmp1.odd_ptr,
                    U.data_ptr, param, 0)  # D_eo
    return latt_tmp0.data[0, :].copy()


def dslash_oe(src_e):
    """Apply D_oe: even-parity source -> odd-parity result.

    `src_e` is copied into the even half of the scratch fermion `latt_tmp1`.
    `pyqcu.dslashQcu` (parity flag 1) then writes into the odd half of
    `latt_tmp0`, and a copy of that odd half is returned.

    The copy matters: `latt_tmp0` is a shared scratch buffer reused by every
    dslash call. Returning a view would let the next call silently overwrite
    a result the caller is still holding.
    """
    latt_tmp1.data[0, :] = src_e
    pyqcu.dslashQcu(latt_tmp0.odd_ptr, latt_tmp1.even_ptr,
                    U.data_ptr, param, 1)  # D_oe
    return latt_tmp0.data[1, :].copy()


def dslash(src_o):
    """Apply the odd-site operator  src_o - kappa**2 * D_oe(D_eo(src_o)).

    This is the operator A inverted by the BiCGSTAB cell (v = A p, t = A s).
    """
    hopped_twice = dslash_oe(dslash_eo(src_o))
    return src_o - (kappa ** 2) * hopped_twice

In [5]:
# give b'_o (stored as b__o): the odd-site right-hand side whose solution is ans_o.
# With the source implied by ans,
#   b_e  = ans_e - kappa * D_eo ans_o
#   b_o  = ans_o - kappa * D_oe ans_e
#   b'_o = b_o + kappa * D_oe b_e  =  ans_o - kappa**2 * D_oe D_eo ans_o
b__o = (
    (ans_o - kappa * dslash_oe(ans_e))
    + kappa * dslash_oe(ans_e - kappa * dslash_eo(ans_o))
)

In [6]:
def dot(a, b):
    """Return the inner product <a, b> = sum(conj(a) * b) over all elements.

    The device is synchronized first, presumably so that any asynchronous
    kernel launched outside CuPy (such as pyqcu's dslash) has finished before
    the reduction reads `a` and `b`.

    `cp.vdot` flattens both arguments and conjugates the first one itself.
    This avoids building the flattened/conjugated temporaries by hand.

    Returns a zero-dimensional CuPy array.
    """
    cp.cuda.runtime.deviceSynchronize()
    return cp.vdot(a, b)


def diff(a, b):
    """Relative L2 distance ||a - b|| / ||b||, with `b` as the reference."""
    cp.cuda.runtime.deviceSynchronize()
    error_norm = cp.linalg.norm(a - b)
    reference_norm = cp.linalg.norm(b)
    return error_norm / reference_norm

In [7]:
# BiCGSTAB setup: initial residual r0 = b'_o - A x_o and the shadow residual.
# BiCGSTAB needs r_tilde to stay equal to r0 for the whole solve.
# It is taken as an independent copy: aliasing it to `r` would let any
# in-place update of `r` silently change it as well.
r = b__o - dslash(x_o)
r_tilde = r.copy()

In [8]:
# BiCGSTAB iterations for dslash(x_o) = b__o, timed with perf_counter.
# Convergence is judged on the relative residual ||r|| / ||b'_o||. An
# absolute threshold as small as TOL lies below the float64 rounding level
# of a residual whose scale is set by ||b'_o|| (a field with millions of
# roughly O(1) entries). Such a threshold is effectively met only when the
# residual is exactly zero.
t0 = perf_counter()
b_norm = cp.linalg.norm(b__o)
for i in range(1, int(MAX_ITER)):
    rho = dot(r_tilde, r)
    beta = (rho / rho_prev) * (alpha / omega)
    p = r + (p - v * omega) * beta
    v = dslash(p)  # v = A p
    alpha = rho / dot(r_tilde, v)
    s = r - v * alpha
    t = dslash(s)  # t = A s
    omega = dot(t, s) / dot(t, t)
    x_o = x_o + p * alpha + s * omega
    r = s - t * omega
    r_norm2 = cp.linalg.norm(r)  # the residual norm itself, not its square
    print("##{}# r_norm2:{}".format(i, r_norm2))
    if (r_norm2 < TOL * b_norm or i == MAX_ITER - 1):
        print("### turns:", i)
        break
    rho_prev = rho
t1 = perf_counter()

##1# r_norm2:211.05940954198724
##2# r_norm2:91.4254398704544
##3# r_norm2:210.11004181163827
##4# r_norm2:116.24409693026254
##5# r_norm2:23.83248463585693
##6# r_norm2:10.402463478131137
##7# r_norm2:7.236147910376961
##8# r_norm2:5.114937575584901
##9# r_norm2:5.313254764479292
##10# r_norm2:2.341656868485409
##11# r_norm2:1.7753773290993604
##12# r_norm2:1.630838448517683
##13# r_norm2:1.2453114064862467
##14# r_norm2:1.1025586132816554
##15# r_norm2:0.6223294561428675
##16# r_norm2:1.060898691427699
##17# r_norm2:0.7197237889696892
##18# r_norm2:0.3142358546768601
##19# r_norm2:0.49171618715660187
##20# r_norm2:0.26367144500041234
##21# r_norm2:0.44904127902562063
##22# r_norm2:0.2760277927469005
##23# r_norm2:0.16345114835612232
##24# r_norm2:0.15557250701981146
##25# r_norm2:0.13707828005828682
##26# r_norm2:0.12191543367243113
##27# r_norm2:0.11283290069233586
##28# r_norm2:0.11876992609560967
##29# r_norm2:0.11321041678925006
##30# r_norm2:0.08041062343573367
##31# r_norm2:0.0

In [9]:
# Spot-check one site of the recovered odd-site solution against the exact
# one, then report the relative error over the whole odd-site field.
print("### ans_o = ", ans_o[0, 0, 0, 0, :, :])
print("### x_o = ", x_o[0, 0, 0, 0, :, :])
print('## difference: ', diff(x_o, ans_o))
# t0/t1 bracket the entire BiCGSTAB loop: every dslash application, inner
# product and per-iteration print. The label therefore names the solve,
# not a single QUDA dslash call.
print(f'BiCGSTAB solve: {t1 - t0} sec')

### ans_o =  [[0.29616538+9.39505944e-01j 0.09386828+6.12764439e-01j
  0.9447394 +4.18722549e-01j]
 [0.58763599+6.02555499e-04j 0.31871455+1.99612606e-01j
  0.89405829+6.90060228e-01j]
 [0.67280703+5.80694740e-01j 0.03719632+7.47318222e-01j
  0.70112218+7.38992539e-01j]
 [0.42383408+7.11157599e-01j 0.5103088 +4.21558372e-01j
  0.82605577+6.71143171e-01j]]
### x_o =  [[0.29616538+9.39505944e-01j 0.09386828+6.12764439e-01j
  0.9447394 +4.18722549e-01j]
 [0.58763599+6.02555499e-04j 0.31871455+1.99612606e-01j
  0.89405829+6.90060228e-01j]
 [0.67280703+5.80694740e-01j 0.03719632+7.47318222e-01j
  0.70112218+7.38992539e-01j]
 [0.42383408+7.11157599e-01j 0.5103088 +4.21558372e-01j
  0.82605577+6.71143171e-01j]]
## difference:  9.820552614762689e-16
Quda dslash: 367.6363070970001 sec


In [10]:
# Reconstruct the even-site solution from the odd-site one:
#   x_e = b_e + kappa * D_eo x_o,   with b_e = ans_e - kappa * D_eo ans_o.
# Hence x_e - ans_e = kappa * D_eo (x_o - ans_o), so x_e ~= ans_e whenever
# x_o ~= ans_o.
x_e = (ans_e - kappa * dslash_eo(ans_o)) + kappa * dslash_eo(x_o)
print(diff(x_e, ans_e))

5.528913503654499e-16
