In [3]:
from pyqcu.gmg import *
from pyqcu import define, gauge, io, qcu, set, bistabcg
import cupy as cp
import numpy as np
import time
import matplotlib.pyplot as plt
# NumPy 2.0 removed the ``np.Inf`` alias; re-create it for code that still
# references it (presumably somewhere inside the pyqcu modules — verify).
setattr(np, "Inf", np.inf)


In [None]:
class WilsonCase:
    """Multi-level Wilson-Dirac test problem driven by pyqcu.

    The constructor prepares lattice parameters and action arguments.
    ``run()`` does the rest:

    * builds a grid hierarchy by halving the Y and Z extents;
    * coarsens the gauge field level by level;
    * registers each level with ``qcu.applyInitQcu``;
    * solves each level with even-odd preconditioning and
      ``bistabcg.slover`` (presumably BiCGStab).

    Per-level data is collected in ``self.lat_dict``.

    Args:
        params: lattice parameters (layout of ``set.params``); copied. If None,
            a 4x32x32x4 (X, Y, Z, T) lattice with data type
            ``define._LAT_C64_`` is configured.
        argv: action arguments (layout of ``set.argv``); copied. If None,
            ``set.argv`` with ``_MASS_`` set to 0.0 is used.
        min_size: smallest Y/Z extent a grid level may have.
        max_levels: maximum number of grid levels.
        seed: seed applied to both the NumPy and the CuPy global RNG.
    """

    def __init__(self, params=None, argv=None, min_size=4, max_levels=10, seed=12138):
        if params is None:
            self.params = set.params.copy()
            self.params[define._LAT_X_] = 4
            self.params[define._LAT_Y_] = 32
            self.params[define._LAT_Z_] = 32
            self.params[define._LAT_T_] = 4
            self.params[define._LAT_XYZT_] = self.params[define._LAT_X_] * \
                self.params[define._LAT_Y_]*self.params[define._LAT_Z_] * \
                self.params[define._LAT_T_]
            self.params[define._DATA_TYPE_] = define._LAT_C64_
            self.params[define._NODE_RANK_] = define.rank
            self.params[define._NODE_SIZE_] = define.size
        else:
            self.params = params.copy()
        if argv is None:
            self.argv = set.argv.copy()
            self.argv[define._MASS_] = 0.0
        else:
            self.argv = argv.copy()
        self.min_size = min_size
        self.max_levels = max_levels
        self.seed = seed
        np.random.seed(seed)
        cp.random.seed(seed)
        self.lat_dict = {
            'params_params': [],
            'src_params': [],
            'dest_params': [],
            'U_params': [],
            'set_ptrs': set.set_ptrs.copy(),
            # Hopping parameter kappa = 1 / (2 * mass + 8).
            'kappa': 1 / (2 * self.argv[define._MASS_] + 8),
            'grid_params': []
        }

    def give_b(self, params):
        """Return a CuPy source vector of ones, reordered by ``io.fermion2psctzyx``.

        ``run()`` indexes the result with ``define._EVEN_`` / ``define._ODD_``,
        so the leading axis is presumably parity — TODO confirm in pyqcu.io.
        """
        b = cp.ones(params[define._LAT_XYZT_]*define._LAT_SC_,
                    dtype=define.dtype(params[define._DATA_TYPE_]))
        b = io.fermion2psctzyx(b, params)
        return b

    def _params_for(self, index, parity):
        """Return a private copy of level ``index``'s parameters with ``parity`` set.

        A copy is used so the stored ``params_params`` entry is never mutated,
        and so two differently-configured parameter sets can coexist.
        """
        level_params = self.lat_dict['params_params'][index].copy()
        level_params[define._PARITY_] = parity
        level_params[define._DAGGER_] = define._NO_USE_
        return level_params

    def _apply_dslash(self, src, index, parity):
        """Apply ``qcu.applyWilsonDslashQcu`` with ``parity`` on level ``index``.

        Returns:
            Host (NumPy) array with the same shape as ``src``.
        """
        d_src = cp.array(src)  # always a fresh device copy
        d_dest = cp.zeros_like(d_src)
        d_U = cp.array(self.lat_dict['U_params'][index])
        qcu.applyWilsonDslashQcu(
            d_dest, d_src, d_U, self.lat_dict['set_ptrs'],
            self._params_for(index, parity))
        return d_dest.get()

    def dslash_eo(self, src, index):
        """Hopping term with ``_PARITY_ = _EVEN_`` on level ``index``.

        Used in ``run()`` as D_eo, mapping an odd-site vector to even sites.
        """
        return self._apply_dslash(src, index, define._EVEN_)

    def dslash_oe(self, src, index):
        """Hopping term with ``_PARITY_ = _ODD_`` on level ``index``.

        Used in ``run()`` as D_oe, mapping an even-site vector to odd sites.
        """
        return self._apply_dslash(src, index, define._ODD_)

    def dslash(self, src, index):
        """Apply the odd-site Schur complement ``src - kappa**2 * D_oe(D_eo(src))``.

        Args:
            src: host array holding an odd-site vector of level ``index``.
            index: grid level.

        Returns:
            Host (NumPy) array with the same shape as ``src``.
        """
        kappa = self.lat_dict['kappa']
        set_ptrs = self.lat_dict['set_ptrs']
        d_src = cp.array(src)
        d_U = cp.array(self.lat_dict['U_params'][index])
        # Two *separate* parameter copies: the hops need different parities.
        # (Sharing one stored object made both hops run with ODD parity.)
        even_params = self._params_for(index, define._EVEN_)
        odd_params = self._params_for(index, define._ODD_)
        hop_e = cp.zeros_like(d_src)
        hop_oe = cp.zeros_like(d_src)
        qcu.applyWilsonDslashQcu(hop_e, d_src, d_U, set_ptrs, even_params)
        qcu.applyWilsonDslashQcu(hop_oe, hop_e, d_U, set_ptrs, odd_params)
        return (d_src - kappa ** 2 * hop_oe).get()

    def _build_grid_params(self, params):
        """Return the (nx, ny, nz) extents of every grid level, finest first.

        nx/ny start from the Y/Z extents and are halved (never below 2) until
        they drop under ``min_size`` or ``max_levels`` levels exist. nz folds
        X, T and the spin-colour dimension together and stays fixed.
        """
        current_nx = params[define._LAT_Y_]
        current_ny = params[define._LAT_Z_]
        current_nz = params[define._LAT_X_] * \
            params[define._LAT_T_]*define._LAT_SC_
        print(
            f"current_nx: {current_nx}, current_ny: {current_ny}, current_nz: {current_nz}")
        grid_params = []
        while min(current_nx, current_ny) >= self.min_size and len(grid_params) < self.max_levels:
            grid_params.append((current_nx, current_ny, current_nz))
            print(
                f"  Level {len(grid_params)-1}: {current_nx}x{current_ny}x{current_nz}")
            current_nx = max(2, current_nx // 2)
            current_ny = max(2, current_ny // 2)
        return grid_params

    @staticmethod
    def _coarsen_gauge(U):
        """Halve the y and then the z extent of a gauge field.

        Assumes the axis layout (c, c, d, p, t, z, y, x), which is the layout
        of the coarse arrays allocated here. Each coarse link is the 3x3
        matrix product of the two neighbouring fine links, formed first along
        y and then along z. The product is taken over the two leading colour
        axes, vectorised over all other axes.
        """
        half_y = U.shape[-2] // 2
        half_z = U.shape[-3] // 2
        U_y = cp.einsum('ab...,bc...->ac...',
                        U[..., 0:2 * half_y:2, :], U[..., 1:2 * half_y:2, :])
        return cp.einsum('ab...,bc...->ac...',
                         U_y[..., 0:2 * half_z:2, :, :],
                         U_y[..., 1:2 * half_z:2, :, :])

    def _solve_level(self, index):
        """Solve level ``index`` with even-odd preconditioning and print residuals.

        Returns:
            (x_e, x_o): host arrays holding the even and odd halves of the solution.
        """
        kappa = self.lat_dict['kappa']
        src = self.lat_dict['src_params'][index]
        b_e = src[define._EVEN_].copy()
        b_o = src[define._ODD_].copy()

        def cp_dslash(vec):
            # bistabcg works on flat CuPy vectors; dslash() takes/returns host arrays.
            return cp.array(self.dslash(vec.get(), index).reshape(vec.shape))

        # Odd-site right-hand side: b_o + kappa * D_oe b_e
        rhs_o = b_o + kappa * self.dslash_oe(b_e, index).reshape(b_e.shape)
        # Solve (1 - kappa^2 D_oe D_eo) x_o = rhs_o
        x_o = bistabcg.slover(
            b=cp.array(rhs_o.flatten()), matvec=cp_dslash, tol=1e-10,
            max_iter=1000000).reshape(rhs_o.shape).get()
        # Reconstruct the even half: x_e = b_e + kappa * D_eo x_o
        hop_eo = self.dslash_eo(x_o, index).reshape(x_o.shape)
        x_e = b_e + kappa * hop_eo
        # Relative residuals of the unpreconditioned system
        #   x_e - kappa D_eo x_o = b_e,  x_o - kappa D_oe x_e = b_o.
        # The even equation holds by construction (it undoes the line above);
        # the odd one is the meaningful check of the solve.
        check_e = x_e - kappa * hop_eo
        check_o = x_o - kappa * self.dslash_oe(x_e, index)
        print(np.linalg.norm(check_e - b_e)/np.linalg.norm(b_e))
        print(np.linalg.norm(check_o - b_o)/np.linalg.norm(b_o))
        return x_e, x_o

    def run(self):
        """Build the grid hierarchy, coarsen the gauge field and solve every level.

        For each level, run() does the following:

        * calls ``qcu.applyInitQcu`` with ``_SET_INDEX_`` equal to the level index;
        * appends host arrays to ``lat_dict['U_params']``,
          ``['params_params']``, ``['src_params']`` and ``['dest_params']``;
        * stores the level extents in ``lat_dict['grid_params']``.

        NOTE(review): ``self.params`` is updated in place to each level's
        extents, so it ends on the coarsest level. The ``lat_dict`` lists are
        also appended to. Calling run() twice on one instance therefore does
        not rebuild the same hierarchy; create a new WilsonCase instead.
        """
        grid_params = self._build_grid_params(self.params)
        self.lat_dict['grid_params'] = grid_params
        U = None
        for i, (nx, ny, _nz) in enumerate(grid_params):
            self.params[define._SET_INDEX_] = i
            self.params[define._SET_PLAN_] = define._SET_PLAN1_
            self.params[define._LAT_Y_] = nx
            self.params[define._LAT_Z_] = ny
            self.params[define._LAT_XYZT_] = self.params[define._LAT_X_] * \
                self.params[define._LAT_Y_]*self.params[define._LAT_Z_] * \
                self.params[define._LAT_T_]
            # Finest level: fresh gauge field; coarser levels: coarsen the previous one.
            if i == 0:
                U = gauge.give_gauge(params=self.params)
            else:
                U = self._coarsen_gauge(U)
            self.lat_dict['U_params'].append(U.get())
            qcu.applyInitQcu(self.lat_dict['set_ptrs'], self.params, self.argv)
            self.lat_dict['params_params'].append(self.params.copy())
            src = self.give_b(self.params)
            self.lat_dict['src_params'].append(src.get())
            # NOTE: qcu.applyWilsonBistabCgQcu was tried here and left disabled
            # (originally marked "BUG???"); the Python bistabcg solver is used instead.
            x_e, x_o = self._solve_level(i)
            dest = cp.zeros_like(src)
            dest[define._EVEN_] = cp.array(x_e)
            dest[define._ODD_] = cp.array(x_o)
            self.lat_dict['dest_params'].append(dest.get())


# Build the default Wilson problem and solve every grid level.
# `case` is read as a global by the cells below.
case = WilsonCase(params=None, argv=None)
case.run()

In [None]:
# Show the (nx, ny, nz) extents of every grid level built by case.run().
case.lat_dict["grid_params"]

In [None]:
class LatticeWilsonOperator:
    """Matrix-free wrapper exposing one grid level of a WilsonCase hierarchy.

    The level is located by matching ``(nx, ny, nz)`` against
    ``wilson_case.lat_dict['grid_params']`` (filled by ``WilsonCase.run``).
    Its index drives :meth:`matvec` and :meth:`give_b`.

    Args:
        nx, ny, nz: grid extents of the requested level.
        dtype: data type; stored as ``self.dtype`` (not used by this class itself).
        wilson_case: object providing ``lat_dict``, ``dslash`` and ``give_b``.
            Defaults to the notebook-global ``case``.

    Raises:
        ValueError: if no grid level has extents ``(nx, ny, nz)``. This used
            to fall back silently to level 0, whose operator has a different size.
    """

    def __init__(self, nx, ny, nz, dtype, wilson_case=None):
        self.nx = nx
        self.ny = ny
        self.nz = nz
        self.dtype = dtype
        # The global `case` is only looked up when no explicit case is given.
        self.wilson_case = case if wilson_case is None else wilson_case
        matches = [
            level
            for level, (level_nx, level_ny, level_nz)
            in enumerate(self.wilson_case.lat_dict['grid_params'])
            if level_nx == nx and level_ny == ny and level_nz == nz
        ]
        if not matches:
            raise ValueError(
                f"no grid level with extents {nx}x{ny}x{nz}")
        # If several levels share the same extents, use the last one.
        self.index = matches[-1]

    def matvec(self, v):
        """Apply the level's preconditioned Wilson operator to ``v``."""
        return self.wilson_case.dslash(src=v, index=self.index)

    def give_b(self):
        """Return the level's right-hand side as built by ``WilsonCase.give_b``."""
        return self.wilson_case.give_b(
            self.wilson_case.lat_dict['params_params'][self.index])

In [5]:
# Solve on the finest (level-0) grid with the geometric multigrid solver,
# then plot and save its convergence history.
level0_params = case.lat_dict['params_params'][0]
level0_nx = level0_params[define._LAT_Y_]
level0_ny = level0_params[define._LAT_Z_]
level0_nz = level0_params[define._LAT_X_] * \
    level0_params[define._LAT_T_] * define._LAT_SC_
solver = GeometricMultigrid(
    nx=level0_nx,
    ny=level0_ny,
    nz=level0_nz,
    dtype=define.dtype(case.params[define._DATA_TYPE_]),
    tolerance=case.argv[define._TOL_],
)
solution = solver.solve()
solver.verify_solution(solution)
print(f"收敛迭代次数: {len(solver.convergence_history)}")
print(f"最终残差: {solver.convergence_history[-1]:.2e}")
# Draw on pyplot's *current* figure, exactly as before (no new figure is created).
plt.title('Adaptive Multigrid Complex Solution Results', fontsize=16)
iteration_numbers = range(1, len(solver.convergence_history) + 1)
plt.semilogy(iteration_numbers, solver.convergence_history,
             'b-o', markersize=4)
plt.tight_layout()
solve_time_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
figure_path = f"Adaptive_Multigrid_Complex_Solution_Results_{solve_time_str}.png"
plt.savefig(figure_path, dpi=300)
print("所有测试完成!")
print("\n" + "=" * 80)


开始自适应多重网格复数求解

构建网格层次结构:
  Level 0: 32x32x192
  Level 1: 16x16x192
  Level 2: 8x8x192
  Level 3: 4x4x192
总共 4 层网格

构建各层系统算子:
Level 0 (32x32x192):
Level 1 (16x16x192):
Level 2 (8x8x192):
Level 3 (4x4x192):

开始多重网格迭代:
------------------------------

迭代 1:
V-循环 level 0, 当前层索引: 3, 网格大小: 32x32x192
    前光滑前残差范数: 2.2923e+02
    前光滑...
    前光滑后残差范数: 7.0176e+00
V-循环 level 1, 当前层索引: 2, 网格大小: 16x16x192
    前光滑前残差范数: 2.7246e+00
    前光滑...
    前光滑后残差范数: 7.4497e-01
V-循环 level 2, 当前层索引: 1, 网格大小: 8x8x192
    前光滑前残差范数: 2.2220e-01
    前光滑...
    前光滑后残差范数: 1.7950e-01
V-循环 level 3, 当前层索引: 0, 网格大小: 4x4x192
    前残差范数: 7.7189e-02
    最粗网格直接求解...
    残差范数: 4.6859e-05
    后光滑前残差范数: 8.9080e-02
    后光滑...
    后光滑后残差范数: 7.5834e-02
    后光滑前残差范数: 7.3976e-01
    后光滑...
    后光滑后残差范数: 4.6968e-01
    后光滑前残差范数: 6.1627e+00
    后光滑...
    后光滑后残差范数: 1.8607e+00
  迭代 1 完成，残差范数: 1.8607e+00

迭代 2:
V-循环 level 0, 当前层索引: 3, 网格大小: 32x32x192
    前光滑前残差范数: 1.8607e+00
    前光滑...
    前光滑后残差范数: 1.2376e+00
V-循环 level 1, 当前层索引: 2, 网格大小:

KeyboardInterrupt: 