In [2]:
from time import perf_counter
import tilelang
import torch
from argparse import Namespace
from pyqcu import lattice, solver, dslash, _torch, tools
import mpi4py.MPI as MPI
import pyqcu
kappa = 0.125
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.complex64
lat_size = [32, 32, 32, 32]
path = pyqcu.__file__.replace('pyqcu/__init__.py', 'examples/data/')
refer_U = tools.hdf5oooxyzt2gridoooxyzt(
    file_name=path+'refer.wilson.U.L32K0_125.ccdxyzt.c64.h5', lat_size=lat_size, device=device, verbose=True)
refer_src = tools.hdf5oooxyzt2gridoooxyzt(
    file_name=path+'refer.wilson.src.L32K0_125.scxyzt.c64.h5', lat_size=lat_size, device=device, verbose=True)
refer_dest = tools.hdf5oooxyzt2gridoooxyzt(
    file_name=path+'refer.wilson.dest.L32K0_125.scxyzt.c64.h5', lat_size=lat_size, device=device, verbose=True)
refer_clover_term = torch.zeros(
    size=[4, 3, 4, 3]+list(refer_src.shape)[2:], dtype=dtype, device=device)
operator = dslash.operator(
    U=refer_U, kappa=kappa, clover_term=refer_clover_term, verbose=True)
time_start = perf_counter()
dest = operator.matvec(src=refer_src)
# dest = dslash.give_wilson(
#     src=refer_src, U=refer_U, kappa=kappa, with_I=True,  verbose=True)
time_end = perf_counter()
is_su3 = lattice.check_su3(refer_U, tol=1e-6, verbose=True)
diff = tools.norm(dest - refer_dest)/tools.norm(refer_dest)

PYQCU::TOOLS::IO:
 rank 0: Grid Lat X: 32, Y: 32, Z: 32, T: 32
PYQCU::TOOLS::IO:
 rank 0: Grid Index X: 0, Y: 0, Z: 0, T: 0
PYQCU::TOOLS::IO:
 rank 0: Dest Shape: (3, 3, 4, 32, 32, 32, 32)
PYQCU::TOOLS::IO:
 rank 0: All Dest Shape: (3, 3, 4, 32, 32, 32, 32)
PYQCU::TOOLS::IO:
 rank 0: Data is loaded from /root/PyQCU/examples/data/refer.wilson.U.L32K0_125.ccdxyzt.c64.h5 (MPI mode)
PYQCU::TOOLS::IO:
 rank 0: Grid Lat X: 32, Y: 32, Z: 32, T: 32
PYQCU::TOOLS::IO:
 rank 0: Grid Index X: 0, Y: 0, Z: 0, T: 0
PYQCU::TOOLS::IO:
 rank 0: Dest Shape: (4, 3, 32, 32, 32, 32)
PYQCU::TOOLS::IO:
 rank 0: All Dest Shape: (4, 3, 32, 32, 32, 32)
PYQCU::TOOLS::IO:
 rank 0: Data is loaded from /root/PyQCU/examples/data/refer.wilson.src.L32K0_125.scxyzt.c64.h5 (MPI mode)
PYQCU::TOOLS::IO:
 rank 0: Grid Lat X: 32, Y: 32, Z: 32, T: 32
PYQCU::TOOLS::IO:
 rank 0: Grid Index X: 0, Y: 0, Z: 0, T: 0
PYQCU::TOOLS::IO:
 rank 0: Dest Shape: (4, 3, 32, 32, 32, 32)
PYQCU::TOOLS::IO:
 rank 0: All Dest Shape: (4, 3, 32, 3

In [4]:
print(
    f"PYQCU::TESTING::DSLASH::WILSON::REFER_U:\n Gauge field SU(3) check: {is_su3}")
print(f"PYQCU::TESTING::DSLASH::WILSON::REFER_U:\n {tools.norm(refer_U)}")
print(
    f"PYQCU::TESTING::DSLASH::WILSON::REFER_U:\n {refer_U.flatten()[:12]}")
print(
    f"PYQCU::TESTING::DSLASH::WILSON::REFER_SRC:\n {tools.norm(refer_src)}")
print(
    f"PYQCU::TESTING::DSLASH::WILSON::REFER_SRC:\n {refer_src.flatten()[:12]}")
print(
    f"PYQCU::TESTING::DSLASH::WILSON::REFER_DEST:\n {tools.norm(refer_dest)}")
print(
    f"PYQCU::TESTING::DSLASH::WILSON::REFER_DEST:\n {refer_dest.flatten()[:12]}")
print(f"PYQCU::TESTING::DSLASH::WILSON::DEST:\n {tools.norm(dest)}")
print(f"PYQCU::TESTING::DSLASH::WILSON::DEST:\n {dest.flatten()[:12]}")
print(
    f"PYQCU::TESTING::DSLASH::WILSON:\n Time cost: {time_end-time_start}")
print(
    f"PYQCU::TESTING::DSLASH::WILSON:\n Difference between computed and reference dslash: {diff}")

PYQCU::TESTING::DSLASH::WILSON::REFER_U:
 Gauge field SU(3) check: True
PYQCU::TESTING::DSLASH::WILSON::REFER_U:
 3547.239990234375
PYQCU::TESTING::DSLASH::WILSON::REFER_U:
 tensor([0.9889+0.0213j, 0.9954-0.0481j, 0.9717+0.0749j, 0.9889+0.0248j,
        0.9872-0.0946j, 0.9790+0.0879j, 0.9903-0.0562j, 0.9928-0.0216j,
        0.9862-0.1220j, 0.9860-0.0216j, 0.9935-0.0203j, 0.9943+0.0137j],
       device='cuda:0')
PYQCU::TESTING::DSLASH::WILSON::REFER_SRC:
 14627.9716796875
PYQCU::TESTING::DSLASH::WILSON::REFER_SRC:
 tensor([ 4.7165-5.4729j,  0.3475-0.1108j, -3.7608-2.9677j, -1.8115+0.8868j,
         0.0147+8.0856j, -0.1822+1.0480j, -0.1321+0.2830j,  0.5528+1.4979j,
         0.0912+1.8120j, -0.0501-1.6487j,  1.9705-0.2234j,  0.3613-0.0824j],
       device='cuda:0')
PYQCU::TESTING::DSLASH::WILSON::REFER_DEST:
 15330.1904296875
PYQCU::TESTING::DSLASH::WILSON::REFER_DEST:
 tensor([ 4.1269-4.7888j, -2.2271+2.1175j, -3.2907-2.5967j, -5.4426-0.6956j,
         0.0129+7.0749j, -0.5508-4.3034j, -0

In [5]:
_dest =dslash.give_wilson(src=refer_src, U=refer_U, kappa=kappa, with_I=True,  verbose=True)

PYQCU::DSLASH::WILSON:
 Applying Dirac operator...
PYQCU::DSLASH::WILSON:
 Source shape: torch.Size([4, 3, 32, 32, 32, 32])
PYQCU::DSLASH::WILSON:
 Gauge field shape: torch.Size([3, 3, 4, 32, 32, 32, 32])
PYQCU::DSLASH::WILSON:
 Source norm: 14627.9716796875
PYQCU::DSLASH::WILSON:
 Processing x (ward=-4)...
PYQCU::DSLASH::WILSON:
 Hopping term norm: 29255.943359375
PYQCU::DSLASH::WILSON:
 Processing y (ward=-3)...
PYQCU::DSLASH::WILSON:
 Hopping term norm: 29255.943359375
PYQCU::DSLASH::WILSON:
 Processing z (ward=-2)...
PYQCU::DSLASH::WILSON:
 Hopping term norm: 29255.943359375
PYQCU::DSLASH::WILSON:
 Processing t (ward=-1)...
PYQCU::DSLASH::WILSON:
 Hopping term norm: 29255.943359375
PYQCU::DSLASH::WILSON:
 Dirac operator application complete
PYQCU::DSLASH::WILSON:
 Dest norm: 15330.1904296875


In [7]:
print(torch.norm(refer_dest-_dest)/torch.norm(refer_dest))

tensor(7.5128e-08, device='cuda:0')


In [8]:
operator.matvec

<bound method operator.matvec of <pyqcu.dslash._operator.operator object at 0x7f7c74f66bc0>>

In [None]:
operator.