In [250]:
import pickle
import torch
import os

def load(path):
    """Unpickle and return the single object stored in the file at `path`.

    NOTE(review): despite the `.pt` extension this uses plain `pickle`, so the
    files are presumably written with `pickle.dump` (output of `torch.save`
    would normally need `torch.load`) -- confirm against the dumping code.
    Unpickling can execute arbitrary code: only use this on trusted dumps.
    """
    with open(path, mode='rb') as handle:
        return pickle.load(handle)

In [251]:
# Ranks whose dumps are loaded, and the tensor tags dumped for every rank.
ranks = list(range(1, 5))
ts = [
    'prenorm',
    'qkv',
    'attn_weights',
    'attn_output',
    'after_residual1',
    'postnorm',
    'up_proj',
    'down_proj',
    'after_residual2',
]

def load_outputs(ranks, ts, dir):
    """Load every `rank{r}_{t}.pt` dump found in directory `dir`.

    Args:
        ranks: rank numbers to load.
        ts: tensor tags to load for each rank.
        dir: directory holding the dump files.

    Returns:
        dict mapping 'rank{r}_{t}' to the object returned by `load`, filled
        rank by rank, tag by tag.

    Raises:
        FileNotFoundError: as soon as an expected dump file is missing.
    """
    loaded = {}
    for rank in ranks:
        for tag in ts:
            key = f'rank{rank}_{tag}'
            file_path = os.path.join(dir, f'{key}.pt')
            # Fail fast with a clear message instead of a partially filled dict.
            if not os.path.exists(file_path):
                raise FileNotFoundError(f'{file_path} does not exist')
            print(f'Loading {file_path}')
            loaded[key] = load(file_path)
    return loaded

# Dump directories of the two runs being compared: the 'tp' run and the 'ga' (galaxy) run.
tpss = ['test_tp', 'test_ga']

# One entry per run directory, loaded in the order listed in `tpss`.
results = {run_dir: load_outputs(ranks, ts, run_dir) for run_dir in tpss}


Loading test_tp/rank1_prenorm.pt
Loading test_tp/rank1_qkv.pt
Loading test_tp/rank1_attn_weights.pt
Loading test_tp/rank1_attn_output.pt
Loading test_tp/rank1_after_residual1.pt
Loading test_tp/rank1_postnorm.pt
Loading test_tp/rank1_up_proj.pt
Loading test_tp/rank1_down_proj.pt
Loading test_tp/rank1_after_residual2.pt
Loading test_tp/rank2_prenorm.pt
Loading test_tp/rank2_qkv.pt
Loading test_tp/rank2_attn_weights.pt
Loading test_tp/rank2_attn_output.pt
Loading test_tp/rank2_after_residual1.pt
Loading test_tp/rank2_postnorm.pt
Loading test_tp/rank2_up_proj.pt
Loading test_tp/rank2_down_proj.pt
Loading test_tp/rank2_after_residual2.pt
Loading test_tp/rank3_prenorm.pt
Loading test_tp/rank3_qkv.pt
Loading test_tp/rank3_attn_weights.pt
Loading test_tp/rank3_attn_output.pt
Loading test_tp/rank3_after_residual1.pt
Loading test_tp/rank3_postnorm.pt
Loading test_tp/rank3_up_proj.pt
Loading test_tp/rank3_down_proj.pt
Loading test_tp/rank3_after_residual2.pt
Loading test_tp/rank4_prenorm.pt
Load

In [252]:
def get_partial_results(pt, rank_ids=None, outputs=None):
    """Collect one tensor tag for every rank from both the tp and the ga run.

    Args:
        pt: tensor tag to collect, e.g. one of the tags in `ts` ('prenorm', 'qkv', ...).
        rank_ids: ranks to collect; defaults to the module-level `ranks`.
        outputs: mapping with 'test_tp' and 'test_ga' entries, each mapping
            'rank{r}_{pt}' names to loaded objects; defaults to the module-level
            `results`.

    Returns:
        (names, partial) where `names` is the list of 'rank{r}_{pt}' keys in rank
        order and `partial` is {'tp': {name: obj, ...}, 'ga': {name: obj, ...}}.

    Raises:
        KeyError: if a run directory or a rank/tag entry is missing.
    """
    # Fall back to the notebook globals so existing `get_partial_results(pt)` calls keep working.
    if rank_ids is None:
        rank_ids = ranks
    if outputs is None:
        outputs = results
    names = [f'rank{rank}_{pt}' for rank in rank_ids]
    partial = {
        'tp': {name: outputs['test_tp'][name] for name in names},
        'ga': {name: outputs['test_ga'][name] for name in names},
    }
    return names, partial

def all_tensors_equal(tensor_list):
    """Return True when every tensor in `tensor_list` equals the first one element-wise.

    An empty list counts as all-equal. If the tensors cannot be stacked
    together (torch.stack raises RuntimeError, e.g. on mismatched shapes),
    they are reported as not equal (False).
    """
    if not tensor_list:
        return True
    try:
        # torch.stack requires all tensors to be stackable (same shape, compatible device).
        batch = torch.stack(tensor_list)
        reference = batch[0]
        return bool((batch == reference).all().item())
    except RuntimeError:
        return False

In [253]:
# prenorm
pt = 'prenorm'
names, partial = get_partial_results(pt)


def check_prenorm(per_rank=False):
    """Compare the 'prenorm' dumps of the tp and ga runs.

    Always prints every dumped shape first (tp, then ga). Reads the module-level
    `names` and `partial` set at the top of this cell.

    Args:
        per_rank: False (default; the check labelled "second layer") -- checks that
            all tp ranks hold identical tensors, concatenates the ga per-rank tensors
            along dim -2, prints both, and compares against tp rank 1 with atol=1e-2.
            True (the check labelled "first layer") -- compares tp and ga rank by
            rank with torch.allclose default tolerances.
    """
    for k, v in partial['tp'].items():
        print(k, v.shape)
    for k, v in partial['ga'].items():
        print(k, v.shape)

    if per_rank:
        for name in names:
            print(torch.allclose(partial['tp'][name], partial['ga'][name]))
        return

    prenorm_tp = [partial['tp'][name] for name in names]
    print(all_tensors_equal(prenorm_tp))
    prenorm_ga = torch.cat([partial['ga'][name] for name in names], dim=-2)
    print(prenorm_tp[0], prenorm_ga)
    print(torch.allclose(prenorm_tp[0], prenorm_ga, atol=1e-2))


check_prenorm()

rank1_prenorm torch.Size([1, 140, 4096])
rank2_prenorm torch.Size([1, 140, 4096])
rank3_prenorm torch.Size([1, 140, 4096])
rank4_prenorm torch.Size([1, 140, 4096])
rank1_prenorm torch.Size([1, 35, 4096])
rank2_prenorm torch.Size([1, 35, 4096])
rank3_prenorm torch.Size([1, 35, 4096])
rank4_prenorm torch.Size([1, 35, 4096])
True
tensor([[[-3.9244e-04, -1.1467e-02,  1.2197e-03,  ...,  3.0197e-02,
           4.0741e-03, -2.5726e-02],
         [ 1.6138e-01,  2.2034e-01, -1.9568e-01,  ...,  1.2805e-01,
          -1.5869e-01,  1.3049e-01],
         [-5.9473e-01,  8.5693e-01, -1.1432e-01,  ...,  6.9458e-02,
           2.0422e-01, -3.9948e-02],
         ...,
         [-1.4477e-03,  1.1536e-02,  3.6157e-01,  ...,  1.4319e-01,
          -3.3472e-01,  6.7676e-01],
         [-4.1602e-01, -8.2568e-01, -1.4307e-01,  ..., -5.8301e-01,
          -2.5366e-01,  5.9570e-01],
         [-1.2732e-01,  3.8110e-01, -1.9092e-01,  ..., -2.6123e-01,
          -3.5156e-01,  3.3374e-01]]], dtype=torch.float16) tens

In [254]:
# qkv
pt = 'qkv'
names, partial = get_partial_results(pt)


def check_qkv():
    """Compare the (q, k, v) triples dumped by the tp and ga runs, rank by rank.

    Prints the three shapes of every dumped entry (tp first, then ga), then one
    line per rank with the torch.allclose results for q (atol=5e-3), k and v
    (atol=8e-3). Reads the module-level `names` and `partial` set in this cell.
    This is the check labelled "first layer (tp)".
    """
    for run in ('tp', 'ga'):
        for key, triple in partial[run].items():
            print(key, triple[0].shape, triple[1].shape, triple[2].shape)

    for name in names:
        q_tp, k_tp, v_tp = partial['tp'][name]
        q_ga, k_ga, v_ga = partial['ga'][name]
        q_close = torch.allclose(q_tp, q_ga, atol=5e-3)
        k_close = torch.allclose(k_tp, k_ga, atol=8e-3)
        v_close = torch.allclose(v_tp, v_ga, atol=8e-3)
        print(q_close, k_close, v_close)

# TODO: second layer (galaxy) variant of this check is not written yet.
# NOTE(review): the saved output of this cell shows False for every q/k/v comparison.
check_qkv()

rank1_qkv torch.Size([1, 140, 1024]) torch.Size([1, 140, 1024]) torch.Size([1, 140, 1024])
rank2_qkv torch.Size([1, 140, 1024]) torch.Size([1, 140, 1024]) torch.Size([1, 140, 1024])
rank3_qkv torch.Size([1, 140, 1024]) torch.Size([1, 140, 1024]) torch.Size([1, 140, 1024])
rank4_qkv torch.Size([1, 140, 1024]) torch.Size([1, 140, 1024]) torch.Size([1, 140, 1024])
rank1_qkv torch.Size([1, 140, 1024]) torch.Size([1, 140, 1024]) torch.Size([1, 140, 1024])
rank2_qkv torch.Size([1, 140, 1024]) torch.Size([1, 140, 1024]) torch.Size([1, 140, 1024])
rank3_qkv torch.Size([1, 140, 1024]) torch.Size([1, 140, 1024]) torch.Size([1, 140, 1024])
rank4_qkv torch.Size([1, 140, 1024]) torch.Size([1, 140, 1024]) torch.Size([1, 140, 1024])
False False False
False False False
False False False
False False False


In [255]:
# attn_weight
pt = 'attn_weights'
names, partial = get_partial_results(pt)


def check_attn_weight():
    """Compare the attention-weight dumps of the tp and ga runs, rank by rank.

    Prints every dumped shape (tp first, then ga), then one torch.allclose
    result (atol=5e-3) per rank. Reads the module-level `names` and `partial`.
    """
    for run in ('tp', 'ga'):
        for key, weights in partial[run].items():
            print(key, weights.shape)

    for name in names:
        matches = torch.allclose(partial['tp'][name], partial['ga'][name], atol=5e-3)
        print(matches)


check_attn_weight()

rank1_attn_weights torch.Size([1, 8, 140, 140])
rank2_attn_weights torch.Size([1, 8, 140, 140])
rank3_attn_weights torch.Size([1, 8, 140, 140])
rank4_attn_weights torch.Size([1, 8, 140, 140])
rank1_attn_weights torch.Size([1, 8, 140, 140])
rank2_attn_weights torch.Size([1, 8, 140, 140])
rank3_attn_weights torch.Size([1, 8, 140, 140])
rank4_attn_weights torch.Size([1, 8, 140, 140])
True
True
True
True


In [256]:
# attn_output
pt = 'attn_output'
names, partial = get_partial_results(pt)


def check_attn_output():
    """Compare the attention-output dumps of the tp and ga runs.

    The tp per-rank tensors are summed element-wise; the ga per-rank tensors are
    concatenated along dim -2. Prints all shapes, both reconstructed tensors,
    and their torch.allclose result (atol=2e-3). Reads the module-level
    `names` and `partial`.
    """
    for run in ('tp', 'ga'):
        for key, out in partial[run].items():
            print(key, out.shape)

    tp_parts = [partial['tp'][name] for name in names]
    ga_parts = [partial['ga'][name] for name in names]
    # Built-in sum keeps the same left-to-right (fp16) accumulation order.
    ao_tp = sum(tp_parts)
    ao_ga = torch.cat(ga_parts, dim=-2)
    print(ao_tp)
    print(ao_ga)
    print(torch.allclose(ao_tp, ao_ga, atol=2e-3))


check_attn_output()

rank1_attn_output torch.Size([1, 140, 4096])
rank2_attn_output torch.Size([1, 140, 4096])
rank3_attn_output torch.Size([1, 140, 4096])
rank4_attn_output torch.Size([1, 140, 4096])
rank1_attn_output torch.Size([1, 35, 4096])
rank2_attn_output torch.Size([1, 35, 4096])
rank3_attn_output torch.Size([1, 35, 4096])
rank4_attn_output torch.Size([1, 35, 4096])
tensor([[[ 0.0553, -0.0155,  0.0037,  ..., -0.2419,  0.1694,  0.1458],
         [-0.1049, -0.1730,  0.0094,  ..., -0.1145,  0.1742,  0.3442],
         [-0.0434, -0.4470,  0.1958,  ...,  0.1692, -0.0738,  0.1276],
         ...,
         [ 0.2024, -0.2805, -0.3152,  ...,  0.0519,  0.0266,  0.1704],
         [ 0.5674,  0.2612, -0.3511,  ...,  0.0698,  0.1479,  0.1974],
         [ 0.0319, -0.0128, -0.0957,  ..., -0.2676,  0.0244,  0.2534]]],
       dtype=torch.float16)
tensor([[[ 0.0552, -0.0157,  0.0037,  ..., -0.2419,  0.1696,  0.1458],
         [-0.1046, -0.1736,  0.0090,  ..., -0.1145,  0.1738,  0.3435],
         [-0.0429, -0.4482,  0.1

In [257]:
# after_residual1
pt = 'after_residual1'
names, partial = get_partial_results(pt)


def check_after_residual1():
    """Compare the first-residual dumps of the tp and ga runs.

    Prints all shapes, whether every tp rank holds an identical tensor, both
    the tp rank-1 tensor and the ga per-rank tensors concatenated along dim -2,
    and their torch.allclose result (atol=4e-3). Reads the module-level
    `names` and `partial`.
    """
    for run in ('tp', 'ga'):
        for key, value in partial[run].items():
            print(key, value.shape)

    tp_copies = [partial['tp'][name] for name in names]
    ga_slices = [partial['ga'][name] for name in names]
    print(all_tensors_equal(tp_copies))
    ga_full = torch.cat(ga_slices, dim=-2)
    print(tp_copies[0])
    print(ga_full)
    print(torch.allclose(tp_copies[0], ga_full, atol=4e-3))


check_after_residual1()

rank1_after_residual1 torch.Size([1, 140, 4096])
rank2_after_residual1 torch.Size([1, 140, 4096])
rank3_after_residual1 torch.Size([1, 140, 4096])
rank4_after_residual1 torch.Size([1, 140, 4096])
rank1_after_residual1 torch.Size([1, 35, 4096])
rank2_after_residual1 torch.Size([1, 35, 4096])
rank3_after_residual1 torch.Size([1, 35, 4096])
rank4_after_residual1 torch.Size([1, 35, 4096])
True
tensor([[[ 0.0495, -0.1875,  0.0240,  ...,  0.2688,  0.2341, -0.2415],
         [ 0.3242,  0.4155, -0.5698,  ...,  0.2715, -0.2754,  0.6943],
         [-1.4766,  1.6270, -0.1108,  ...,  0.3589,  0.4507,  0.0303],
         ...,
         [ 0.1985, -0.2493,  0.7705,  ...,  0.4900, -0.9355,  2.0117],
         [-0.5596, -1.9844, -0.7822,  ..., -1.7188, -0.5840,  1.8242],
         [-0.3687,  1.1904, -0.7637,  ..., -1.1982, -1.1523,  1.3105]]],
       dtype=torch.float16)
tensor([[[ 0.0493, -0.1855,  0.0235,  ...,  0.2688,  0.2333, -0.2402],
         [ 0.3262,  0.4163, -0.5728,  ...,  0.2715, -0.2734,  0.69

In [258]:
# postnorm
pt = 'postnorm'
names, partial = get_partial_results(pt)


def check_postnorm():
    """Compare the post-norm dumps of the tp and ga runs.

    Prints all shapes, whether every tp rank holds an identical tensor, both
    the tp rank-1 tensor and the ga per-rank tensors concatenated along dim -2,
    and their torch.allclose result (atol=4e-3). Reads the module-level
    `names` and `partial`.
    """
    for run in ('tp', 'ga'):
        for key, value in partial[run].items():
            print(key, value.shape)

    tp_norms = [partial['tp'][name] for name in names]
    ga_pieces = [partial['ga'][name] for name in names]
    print(all_tensors_equal(tp_norms))
    ga_norm = torch.cat(ga_pieces, dim=-2)
    print(tp_norms[0])
    print(ga_norm)
    print(torch.allclose(tp_norms[0], ga_norm, atol=4e-3))


check_postnorm()

rank1_postnorm torch.Size([1, 140, 4096])
rank2_postnorm torch.Size([1, 140, 4096])
rank3_postnorm torch.Size([1, 140, 4096])
rank4_postnorm torch.Size([1, 140, 4096])
rank1_postnorm torch.Size([1, 35, 4096])
rank2_postnorm torch.Size([1, 35, 4096])
rank3_postnorm torch.Size([1, 35, 4096])
rank4_postnorm torch.Size([1, 35, 4096])
True
tensor([[[ 0.0071, -0.0269,  0.0035,  ...,  0.0375,  0.0316, -0.0342],
         [ 0.1134,  0.1453, -0.2020,  ...,  0.0923, -0.0906,  0.2394],
         [-0.5449,  0.6006, -0.0415,  ...,  0.1289,  0.1565,  0.0110],
         ...,
         [ 0.0651, -0.0818,  0.2563,  ...,  0.1564, -0.2888,  0.6509],
         [-0.1840, -0.6523, -0.2607,  ..., -0.5498, -0.1808,  0.5918],
         [-0.1018,  0.3289, -0.2137,  ..., -0.3218, -0.2996,  0.3569]]],
       dtype=torch.float16)
tensor([[[ 0.0071, -0.0267,  0.0034,  ...,  0.0376,  0.0316, -0.0341],
         [ 0.1140,  0.1455, -0.2030,  ...,  0.0923, -0.0900,  0.2399],
         [-0.5479,  0.6016, -0.0423,  ...,  0.1301,

In [259]:
# up_proj
pt = 'up_proj'
names, partial = get_partial_results(pt)


def check_up_proj(show_shapes=False):
    """Compare the up_proj dumps of the tp and ga runs.

    The per-rank tensors of each run are concatenated along the last dim; the
    two concatenations' shapes and values are printed, followed by their
    torch.allclose result (rtol=0.1, atol=1e-5). Reads the module-level
    `names` and `partial`.

    Args:
        show_shapes: if True, first print the shape of every per-rank dump
            (tp, then ga). Defaults to False, which leaves the output unchanged.
    """
    if show_shapes:
        for run in ('tp', 'ga'):
            for k, v in partial[run].items():
                print(k, v.shape)

    tp_up = torch.cat([partial['tp'][name] for name in names], dim=-1)
    ga_up = torch.cat([partial['ga'][name] for name in names], dim=-1)
    print(tp_up.shape)
    print(ga_up.shape)
    print(tp_up)
    print(ga_up)
    # NOTE(review): rtol=0.1 is a very loose tolerance -- confirm it is intentional.
    print(torch.allclose(tp_up, ga_up, rtol=0.1, atol=1e-5))


check_up_proj()

torch.Size([1, 140, 11008])
torch.Size([1, 140, 11008])
tensor([[[-9.0942e-03,  3.0701e-02, -1.8829e-02,  ..., -7.2365e-03,
          -1.0155e-02,  5.0621e-03],
         [ 6.6910e-03, -4.0894e-03, -1.2286e-01,  ..., -6.5369e-02,
           2.3767e-01, -6.8298e-02],
         [ 3.2496e-04, -2.1716e-01, -1.8591e-01,  ..., -2.5122e-01,
           5.7324e-01,  2.1500e-02],
         ...,
         [ 4.8126e-02, -1.0361e+00, -5.7007e-02,  ..., -7.0020e-01,
           5.2148e-01,  1.9690e-01],
         [ 2.7725e-02,  2.9861e-02,  4.9292e-01,  ..., -7.5195e-02,
           5.5225e-01, -1.3879e-01],
         [ 3.6523e-01, -4.4098e-02,  1.1017e-01,  ...,  1.7471e-02,
           7.7100e-01, -4.7302e-02]]], dtype=torch.float16)
tensor([[[-9.0485e-03,  3.0518e-02, -1.8524e-02,  ..., -7.2021e-03,
          -1.0086e-02,  5.0888e-03],
         [ 6.6719e-03, -4.1733e-03, -1.2256e-01,  ..., -6.5491e-02,
           2.3792e-01, -6.8298e-02],
         [ 4.2343e-04, -2.1655e-01, -1.8567e-01,  ..., -2.5146e-01,

In [260]:
# down_proj
pt = 'down_proj'
names, partial = get_partial_results(pt)

# preivous layers
# def check_down_proj():
# # def check_attn_output():
#     # for k, v in partial['tp'].items():
#     #     print(k, v.shape)
#     # for k, v in partial['ga'].items():
#     #     print(k, v.shape)
    
#     dp_tp = sum([partial['tp'][name] for name in names])
#     dp_ga = torch.cat([partial['ga'][name] for name in names], dim=-2)
#     print(dp_tp)
#     print(dp_ga)
#     print(torch.allclose(dp_tp, dp_ga, atol=1e-3))

# last layer
def check_down_proj():
    dp_tp = sum([partial['tp'][name] for name in names])
    dp_ga = sum([partial['ga'][name] for name in names])
    print(dp_tp)
    print(dp_ga)
    print(torch.allclose(dp_tp, dp_ga, rtol=2e-1))

check_down_proj()


tensor([[[ 3.7817e-01,  7.4036e-02, -1.8311e-03,  ..., -2.4097e-01,
          -2.3267e-01,  3.0566e-01],
         [-5.8008e-01,  9.7070e-01, -3.5840e-01,  ...,  3.0566e-01,
          -2.3486e-01,  7.9248e-01],
         [ 3.8916e-01, -1.3164e+00,  6.9385e-01,  ...,  7.0312e-02,
           6.2158e-01, -6.8750e-01],
         ...,
         [ 2.8672e+00,  5.4492e-01, -2.4082e+00,  ...,  3.8330e-01,
          -5.0488e-01, -4.0771e-01],
         [-2.3633e-01,  1.4531e+00,  7.3145e-01,  ...,  4.3457e-01,
          -1.4832e-01, -4.0332e-01],
         [-1.7676e-01,  1.2402e+00, -2.4512e-01,  ...,  1.5781e+00,
          -2.7246e-01, -3.0762e-02]]], dtype=torch.float16)
tensor([[[ 3.7769e-01,  7.3975e-02, -2.0752e-03,  ..., -2.4072e-01,
          -2.3267e-01,  3.0518e-01],
         [-5.8057e-01,  9.7119e-01, -3.5889e-01,  ...,  3.0420e-01,
          -2.3438e-01,  7.9248e-01],
         [ 3.9160e-01, -1.3184e+00,  6.9434e-01,  ...,  6.9946e-02,
           6.2256e-01, -6.8750e-01],
         ...,
    

In [263]:
# after_residual2
pt = 'after_residual2'
names, partial = get_partial_results(pt)


# last_layer
def check_after_residual2(show_shapes=False):
    """Inspect the second-residual dumps of the tp and ga runs.

    Prints whether all tp ranks hold identical tensors, whether all ga ranks
    hold identical tensors, and then the rank-1 tensor of each run. Reads the
    module-level `names` and `partial`.

    NOTE(review): unlike the other checks, this one never compares tp against
    ga with torch.allclose -- add a comparison once a tolerance is agreed.

    Args:
        show_shapes: if True, first print the shape of every per-rank dump
            (tp, then ga). Defaults to False, which leaves the output unchanged.
    """
    if show_shapes:
        for run in ('tp', 'ga'):
            for k, v in partial[run].items():
                print(k, v.shape)

    ar2_tp = [partial['tp'][name] for name in names]
    ar2_ga = [partial['ga'][name] for name in names]
    print(all_tensors_equal(ar2_tp))
    print(all_tensors_equal(ar2_ga))
    print(ar2_tp[0])
    print(ar2_ga[0])


check_after_residual2()

True
True
tensor([[[ 4.2749e-01, -1.1340e-01,  2.2141e-02,  ...,  2.7832e-02,
           1.5869e-03,  6.3965e-02],
         [-2.5635e-01,  1.3867e+00, -9.2822e-01,  ...,  5.7764e-01,
          -5.1025e-01,  1.4873e+00],
         [-1.0879e+00,  3.1152e-01,  5.8252e-01,  ...,  4.2920e-01,
           1.0723e+00, -6.5723e-01],
         ...,
         [ 3.0664e+00,  2.9517e-01, -1.6377e+00,  ...,  8.7354e-01,
          -1.4414e+00,  1.6055e+00],
         [-7.9590e-01, -5.3125e-01, -5.0781e-02,  ..., -1.2842e+00,
          -7.3242e-01,  1.4209e+00],
         [-5.4590e-01,  2.4297e+00, -1.0088e+00,  ...,  3.7891e-01,
          -1.4248e+00,  1.2793e+00]]], dtype=torch.float16)
tensor([[[ 4.2676e-01, -1.1163e-01,  2.1393e-02,  ...,  2.7954e-02,
           6.1035e-04,  6.4819e-02],
         [-2.5488e-01,  1.3867e+00, -9.3164e-01,  ...,  5.7617e-01,
          -5.0781e-01,  1.4873e+00],
         [-1.0928e+00,  3.1055e-01,  5.8105e-01,  ...,  4.3213e-01,
           1.0742e+00, -6.5674e-01],
        