In [7]:
from pyqcu.gmg import *
from pyqcu import define, gauge, io, qcu, set, bistabcg
import cupy as cp
import numpy as np
import time
import matplotlib.pyplot as plt
np.Inf = np.inf

In [None]:
class WilsonCase:
    def __init__(self, params=None, argv=None, min_size=4, max_levels=10, seed=12138):
        if params is None:
            self.params = set.params.copy()
            self.params[define._LAT_X_] = 16*define._LAT_P_
            self.params[define._LAT_Y_] = 16
            self.params[define._LAT_Z_] = 4
            self.params[define._LAT_T_] = 4
            self.params[define._LAT_XYZT_] = self.params[define._LAT_X_] * \
                self.params[define._LAT_Y_]*self.params[define._LAT_Z_] * \
                self.params[define._LAT_T_]
            self.params[define._DATA_TYPE_] = define._LAT_C64_
            self.params[define._NODE_RANK_] = define.rank
            self.params[define._NODE_SIZE_] = define.size
        else:
            self.params = params.copy()
        if argv is None:
            self.argv = set.argv.copy()
            self.argv[define._MASS_] = 0.0
        else:
            self.argv = argv.copy()
        self.min_size = min_size
        self.max_levels = max_levels
        self.seed = seed
        np.random.seed(seed)
        cp.random.seed(seed)
        self.lat_dict = {
            'params_params': [],
            'U_params': [],
            'set_ptrs': set.set_ptrs.copy(),
            'kappa': 1 / (2 * self.argv[define._MASS_] + 8),
            'grid_params': []
        }

    def give_b(self, params):
        b = cp.ones(params[define._LAT_XYZT_]*define._LAT_SC_//define._LAT_P_,
                    dtype=define.dtype(params[define._DATA_TYPE_]))
        b = io.fermion2sctzyx(b, params)
        return b

    def dslash_eo(self, src, index):
        eo_params = self.lat_dict['params_params'][index]
        eo_params[define._PARITY_] = define._EVEN_
        eo_params[define._DAGGER_] = define._NO_USE_
        _src = cp.array(src.copy())
        _dest = cp.zeros_like(_src)
        _U = cp.array(self.lat_dict['U_params'][index])
        _set_ptrs = self.lat_dict['set_ptrs']
        qcu.applyWilsonDslashQcu(
            _dest, _src, _U, _set_ptrs, eo_params)
        return _dest.get()

    def dslash_oe(self, src, index):
        oe_params = self.lat_dict['params_params'][index]
        oe_params[define._PARITY_] = define._ODD_
        oe_params[define._DAGGER_] = define._NO_USE_
        _src = cp.array(src.copy())
        _dest = cp.zeros_like(_src)
        _U = cp.array(self.lat_dict['U_params'][index])
        _set_ptrs = self.lat_dict['set_ptrs']
        qcu.applyWilsonDslashQcu(
            _dest, _src, _U, _set_ptrs, oe_params)
        return _dest.get()

    def dslash(self, src, index):
        eo_params = self.lat_dict['params_params'][index]
        eo_params[define._PARITY_] = define._EVEN_
        eo_params[define._DAGGER_] = define._NO_USE_
        oe_params = self.lat_dict['params_params'][index]
        oe_params[define._PARITY_] = define._ODD_
        oe_params[define._DAGGER_] = define._NO_USE_
        _src = cp.array(src.copy())
        tmp0 = cp.zeros_like(_src)
        tmp1 = cp.zeros_like(_src)
        _dest = cp.zeros_like(_src)
        _U = cp.array(self.lat_dict['U_params'][index])
        _set_ptrs = self.lat_dict['set_ptrs']
        qcu.applyWilsonDslashQcu(
            tmp0, _src, _U, _set_ptrs, eo_params)
        qcu.applyWilsonDslashQcu(
            tmp1, tmp0, _U, _set_ptrs, oe_params)
        _dest = _src-self.lat_dict['kappa']**2*tmp1
        return _dest.get()

    def run(self):
        current_nx, current_ny, current_nz = self.params[define._LAT_X_]//define._LAT_P_, self.params[
            define._LAT_Y_], self.params[define._LAT_Z_]*self.params[define._LAT_T_]*define._LAT_SC_
        print(
            f"current_nx: {current_nx}, current_ny: {current_ny}, current_nz: {current_nz}")
        while min(current_nx, current_ny) >= self.min_size and len(self.lat_dict['grid_params']) < self.max_levels:
            self.lat_dict['grid_params'].append(
                (current_nx, current_ny, current_nz))
            print(
                f"  Level {len(self.lat_dict['grid_params'])-1}: {current_nx}x{current_ny}x{current_nz}")
            current_nx = max(2, current_nx // 2)
            current_ny = max(2, current_ny // 2)
        _params = self.params.copy()
        for i, (nx, ny, nz) in enumerate(self.lat_dict['grid_params']):
            _params[define._SET_INDEX_] = i
            _params[define._SET_PLAN_] = define._SET_PLAN1_
            _params[define._LAT_X_] = nx*define._LAT_P_
            _params[define._LAT_Y_] = ny
            _params[define._LAT_XYZT_] = _params[define._LAT_X_] * \
                _params[define._LAT_Y_]*_params[define._LAT_Z_] * \
                _params[define._LAT_T_]
            if i == 0:
                U = gauge.give_gauss_SU3(sigma=0.1, seed=self.seed,
                                         dtype=define.dtype(_params[define._DATA_TYPE_]), size=_params[define._LAT_XYZT_]*define._LAT_S_)
                U = io.dptzyxcc2ccdptzyx(
                    io.gauge2dptzyxcc(gauge=U, params=_params))
            else:
                _U = U.copy()
                U = cp.zeros((define._LAT_C_, define._LAT_C_, define._LAT_D_, define._LAT_P_, _params[define._LAT_T_],
                              _params[define._LAT_Z_], _params[define._LAT_Y_], _params[define._LAT_X_]//define._LAT_P_), dtype=U.dtype)
                for d in range(define._LAT_D_):
                    for p in range(define._LAT_P_):
                        for t in range(_params[define._LAT_T_]):
                            for z in range(_params[define._LAT_Z_]):
                                for y in range(_params[define._LAT_Y_]):
                                    for x in range(_params[define._LAT_X_]//define._LAT_P_):
                                        U[:, :, d, p, t, z, y, x] = (_U[:, :, d, p, t,
                                                                        z, y*2, x*2] @ _U[:, :, d, p, t, z, y*2, x*2+1])@(_U[:, :, d, p, t,
                                                                                                                             z, y*2+1, x*2] @ _U[:, :, d, p, t, z, y*2+1, x*2+1])
            self.lat_dict['U_params'].append(U.get())
            qcu.applyInitQcu(self.lat_dict['set_ptrs'], _params, self.argv)
            self.lat_dict['params_params'].append(_params.copy())


case = WilsonCase()
case.run()

current_nx: 16, current_ny: 16, current_nz: 192
  Level 0: 16x16x192
  Level 1: 8x8x192
  Level 2: 4x4x192


KeyboardInterrupt: 

: 

In [None]:
class LatticeWilsonOperator:
    def __init__(self, nx, ny, nz, dtype=define.dtype(case.params[define._DATA_TYPE_])):
        self.nx = nx
        self.ny = ny
        self.nz = nz
        self.dtype = dtype
        self.index = 0
        for i, (nx, ny, nz) in enumerate(case.lat_dict['grid_params']):
            if nx == self.nx and ny == self.ny and nz == self.nz:
                self.index = i

    def matvec(self, v):
        return case.dslash(src=v, index=self.index)

    def give_b(self):
        return case.give_b(case.lat_dict['params_params'][self.index])

In [None]:
for i, (nx, ny, nz) in enumerate(case.lat_dict['grid_params']):
    src = cp.ones((define._LAT_P_, define._LAT_S_, define._LAT_C_,
                  case.params[define._LAT_T_], case.params[define._LAT_Z_], ny, nx), dtype=define.dtype(case.params[define._DATA_TYPE_]))
    def dslash_eo(src):
        return case.dslash_eo(src, i).reshape(src.shape)
    def dslash_oe(src):
        return case.dslash_oe(src, i).reshape(src.shape)
    def cp_dslash(src):
        return cp.array(case.dslash(src.get(), i).reshape(src.shape))
    b_e = src[define._EVEN_].get()
    b_o = src[define._ODD_].get()
    b__o = np.zeros_like(b_o)
    tmp = np.zeros_like(b_o)
    # b__o=b_o+kappa*D_oe(b_e)
    tmp = dslash_oe(b_e)
    b__o = b_o+case.lat_dict['kappa']*tmp
    # Dslash(x_o)=b__o
    x_o = bistabcg.slover(
        b=cp.array(b__o.flatten()), matvec=cp_dslash, tol=case.argv[define._TOL_], max_iter=case.params[define._MAX_ITER_]).reshape(b__o.shape).get()
    # x_e  =b_e+kappa*D_eo(x_o)
    tmp = dslash_eo(x_o)
    x_e = b_e + case.lat_dict['kappa']*tmp
    # dest = cp.zeros_like(src)
    # print(f"dest.shape={dest.shape}")
    # qcu.applyWilsonBistabCgQcu(dest, src,
    #                            cp.array(case.lat_dict['U_params'][i]), case.lat_dict['set_ptrs'], case.lat_dict['params_params'][i])
    # print(np.linalg.norm(dest[define._EVEN_].get()-x_e)/np.linalg.norm(x_e))
    # print(np.linalg.norm(dest[define._ODD_].get()-x_o)/np.linalg.norm(x_o))
    _b_e = x_e-case.lat_dict['kappa']*case.dslash_eo(x_o, i)
    _b_o = x_o-case.lat_dict['kappa']*case.dslash_oe(x_e, i)
    print(np.linalg.norm(_b_e-b_e)/np.linalg.norm(b_e))
    print(np.linalg.norm(_b_o-b_o)/np.linalg.norm(b_o))

multi-gpu wilson dslash total time: (without malloc free memcpy) :0.001946139 sec
multi-gpu wilson dslash total time: (without malloc free memcpy) :0.000191979 sec
multi-gpu wilson dslash total time: (without malloc free memcpy) :0.000145042 sec
multi-gpu wilson dslash total time: (without malloc free memcpy) :0.001138537 sec
multi-gpu wilson dslash total time: (without malloc free memcpy) :0.000094648 sec
multi-gpu wilson dslash total time: (without malloc free memcpy) :0.000134519 sec
multi-gpu wilson dslash total time: (without malloc free memcpy) :0.000139190 sec
Iteration 0: Residual = 1.920730e+05, Time = 0.011031 s
multi-gpu wilson dslash total time: (without malloc free memcpy) :0.000447372 sec
multi-gpu wilson dslash total time: (without malloc free memcpy) :0.000561817 sec
multi-gpu wilson dslash total time: (without malloc free memcpy) :0.000144906 sec
multi-gpu wilson dslash total time: (without malloc free memcpy) :0.000109731 sec
Iteration 1: Residual = 9.866877e+04, Time

In [None]:
class LatticeGeometricMultigrid(GeometricMultigrid):
    # def restrict(self, u_fine, nx_fine, ny_fine, nz_fine):
    #     print(
    #         f"restrict nx_fine={nx_fine}, ny_fine={ny_fine}, nz_fine={nz_fine}")
    #     if nx_fine < 2 or ny_fine < 2:
    #         return u_fine
    #     nx_coarse = max(2, nx_fine // 2)
    #     ny_coarse = max(2, ny_fine // 2)
    #     nz_coarse = nz_fine
    #     u_fine_3d = u_fine.reshape((nz_fine, ny_fine, nx_fine))
    #     u_coarse_3d = np.zeros(
    #         (nz_coarse, ny_coarse, nx_coarse), dtype=self.dtype)
    #     for k in range(nz_fine):
    #         u_fine_slice = u_fine_3d[k, :, :]
    #         u_coarse_slice = np.zeros((ny_coarse, nx_coarse), dtype=self.dtype)
    #         for i in range(ny_coarse):
    #             for j in range(nx_coarse):
    #                 ii, jj = 2*i, 2*j
    #                 weight_sum = 0
    #                 value_sum = 0
    #                 for di in [-1, 0, 1]:
    #                     for dj in [-1, 0, 1]:
    #                         ni, nj = ii + di, jj + dj
    #                         if 0 <= ni < ny_fine and 0 <= nj < nx_fine:
    #                             if di == 0 and dj == 0:
    #                                 weight = 1/4
    #                             elif di == 0 or dj == 0:
    #                                 weight = 1/8
    #                             else:
    #                                 weight = 1/16
    #                             weight_sum += weight
    #                             value_sum += weight * u_fine_slice[ni, nj]
    #                 u_coarse_slice[i, j] = value_sum / \
    #                     weight_sum if weight_sum > 0 else 0
    #         u_coarse_3d[k, :, :] = u_coarse_slice
    #     # params_fine = None
    #     # params_coarse = None
    #     # for i, (nx, ny, nz) in enumerate(case.lat_dict['grid_params']):
    #     #     if nx == nx_fine and ny == ny_fine and nz == nz_fine:
    #     #         params_fine = case.lat_dict['params_params'][i]
    #     #     if nx == nx_coarse and ny == ny_coarse and nz == nz_coarse:
    #     #         params_coarse = case.lat_dict['params_params'][i]
    #     # _u_fine = u_fine.reshape((define._LAT_S_, define._LAT_C_,  params_fine[define._LAT_T_],
    #     #                           params_fine[define._LAT_Z_], params_fine[define._LAT_Y_], params_fine[define._LAT_X_]//define._LAT_P_))
    #     # _u_coarse = np.zeros((define._LAT_S_, define._LAT_C_, params_coarse[define._LAT_T_],
    #     #                       params_coarse[define._LAT_Z_], params_coarse[define._LAT_Y_], params_coarse[define._LAT_X_]//define._LAT_P_), dtype=u_fine.dtype)
    #     # for t in range(params_coarse[define._LAT_T_]):
    #     #     for z in range(params_coarse[define._LAT_Z_]):
    #     #         for y in range(params_coarse[define._LAT_Y_]):
    #     #             for x in range(params_coarse[define._LAT_X_]//define._LAT_P_):
    #     #                 _u_coarse[:, :, t,
    #     #                           z, y, x] = (_u_fine[:, :, t,
    #     #                                               z, y*2, x*2] + _u_fine[:, :,  t, z, y*2, x*2+1])+(_u_fine[:, :, t,
    #     #                                                                                                         z, y*2+1, x*2] + _u_fine[:, :, t, z, y*2+1, x*2+1])
    #     # return _u_coarse.flatten()
    #     return u_coarse_3d.flatten()

    # def prolongate(self, u_coarse, nx_fine, ny_fine, nz_fine):
    #     print(
    #         f"prolongate nx_fine={nx_fine}, ny_fine={ny_fine}, nz_fine={nz_fine}")
    #     nx_coarse = nx_fine // 2
    #     ny_coarse = ny_fine // 2
    #     nz_coarse = nz_fine
    #     u_coarse_3d = u_coarse.reshape((nz_coarse, ny_coarse, nx_coarse))
    #     u_fine_3d = np.zeros((nz_fine, ny_fine, nx_fine), dtype=self.dtype)
    #     for k in range(nz_fine):
    #         u_coarse_slice = u_coarse_3d[k, :, :]
    #         u_fine_slice = np.zeros((ny_fine, nx_fine), dtype=self.dtype)
    #         for i in range(ny_fine):
    #             for j in range(nx_fine):
    #                 i_c = i / 2.0
    #                 j_c = j / 2.0
    #                 i0, j0 = int(i_c), int(j_c)
    #                 i1 = min(i0 + 1, ny_coarse - 1)
    #                 j1 = min(j0 + 1, nx_coarse - 1)
    #                 wx = i_c - i0
    #                 wy = j_c - j0
    #                 u_fine_slice[i, j] = (1 - wx) * (1 - wy) * u_coarse_slice[i0, j0] + \
    #                                      (1 - wx) * wy * u_coarse_slice[i0, j1] + \
    #                     wx * (1 - wy) * u_coarse_slice[i1, j0] + \
    #                     wx * wy * u_coarse_slice[i1, j1]
    #         u_fine_3d[k, :, :] = u_fine_slice
    #     # params_fine = None
    #     # params_coarse = None
    #     # for i, (nx, ny, nz) in enumerate(case.lat_dict['grid_params']):
    #     #     if nx == nx_fine and ny == ny_fine and nz == nz_fine:
    #     #         params_fine = case.lat_dict['params_params'][i]
    #     #     if nx == nx_coarse and ny == ny_coarse and nz == nz_coarse:
    #     #         params_coarse = case.lat_dict['params_params'][i]
    #     # _u_coarse = u_coarse.reshape((define._LAT_S_, define._LAT_C_, params_coarse[define._LAT_T_],
    #     #                               params_coarse[define._LAT_Z_], params_coarse[define._LAT_Y_], params_coarse[define._LAT_X_]//define._LAT_P_))
    #     # _u_fine = np.zeros((define._LAT_S_, define._LAT_C_,  params_fine[define._LAT_T_],
    #     #                     params_fine[define._LAT_Z_], params_fine[define._LAT_Y_], params_fine[define._LAT_X_]//define._LAT_P_), dtype=u_coarse.dtype)
    #     # for t in range(params_fine[define._LAT_T_]):
    #     #     for z in range(params_fine[define._LAT_Z_]):
    #     #         for y in range(params_fine[define._LAT_Y_]):
    #     #             for x in range(params_fine[define._LAT_X_]//define._LAT_P_):
    #     #                 _u_fine[:, :, t,
    #     #                         z, y, x] = _u_coarse[:, :, t, z, y//2, x//2]
    #     # return _u_fine.flatten()
    #     return u_fine_3d.flatten()
    pass

In [None]:
solver = LatticeGeometricMultigrid(nx=case.params[define._LAT_X_]//define._LAT_P_, ny=case.params[define._LAT_Y_], nz=case.params[define._LAT_Z_]
                                   * case.params[define._LAT_T_]*define._LAT_SC_, dtype=define.dtype(case.params[define._DATA_TYPE_]), tolerance=case.argv[define._TOL_], min_size=case.min_size, max_levels=case.max_levels, max_iterations=case.params[define._MAX_ITER_])
solution = solver.solve()
solver.verify_solution(solution)
print(f"收敛迭代次数: {len(solver.convergence_history)}")
print(f"最终残差: {solver.convergence_history[-1]:.2e}")
plt.title(
    f'Adaptive Multigrid Complex Solution Results', fontsize=16)
plt.semilogy(range(1, len(solver.convergence_history) + 1),
             solver.convergence_history, 'b-o', markersize=4)
plt.tight_layout()
solve_time_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
plt.savefig(
    f"Adaptive_Multigrid_Complex_Solution_Results_{solve_time_str}.png", dpi=300)
print("所有测试完成!")
print(f"\n{'='*80}")


开始自适应多重网格复数求解

构建网格层次结构:
  Level 0: 16x16x192
  Level 1: 8x8x192
  Level 2: 4x4x192
总共 3 层网格

构建各层系统算子:
Level 0 (16x16x192):
Level 1 (8x8x192):
Level 2 (4x4x192):

开始多重网格迭代:
------------------------------

迭代 1:
V-循环 level 0, 当前层索引: 2, 网格大小: 16x16x192
    前光滑前残差范数: 1.1809e+02
    前光滑...
    前光滑后残差范数: 3.9596e+00
V-循环 level 1, 当前层索引: 1, 网格大小: 8x8x192
    前光滑前残差范数: 1.8889e+00
    前光滑...
    前光滑后残差范数: 1.3632e+00
V-循环 level 2, 当前层索引: 0, 网格大小: 4x4x192
    前残差范数: 5.6166e-01
    最粗网格直接求解...
    残差范数: 4.5759e-04
    后光滑前残差范数: 8.7157e-01
    后光滑...
    后光滑后残差范数: 3.9429e-01
    后光滑前残差范数: 2.3186e+00
    后光滑...
    后光滑后残差范数: 1.6540e+00
  迭代 1 完成，残差范数: 1.6540e+00

迭代 2:
V-循环 level 0, 当前层索引: 2, 网格大小: 16x16x192
    前光滑前残差范数: 1.6540e+00
    前光滑...
    前光滑后残差范数: 1.2712e+00
V-循环 level 1, 当前层索引: 1, 网格大小: 8x8x192
    前光滑前残差范数: 4.6474e-01
    前光滑...
    前光滑后残差范数: 4.3084e-01
V-循环 level 2, 当前层索引: 0, 网格大小: 4x4x192
    前残差范数: 1.6448e-01
    最粗网格直接求解...
    残差范数: 1.1146e-04
    后光滑前残差范数: 2.6005e-01
    后光滑...
 

KeyboardInterrupt: 