In [1]:
import torch
import transformer_cpp

# Test Multi-Head Attention
batch_size, seq_len, embedding_dim, num_heads = 2, 8, 256, 8
query = torch.randn(batch_size, seq_len, embedding_dim).cuda()
key = torch.randn(batch_size, seq_len, embedding_dim).cuda()
value = torch.randn(batch_size, seq_len, embedding_dim).cuda()
mask = torch.triu(torch.ones(seq_len, seq_len) * float('-inf'), diagonal=1).cuda()

output = transformer_cpp.multi_head_self_attention(query, key, value, mask, num_heads)
print("Attention Output Shape:", output.shape)

# Test MoE
input = torch.randn(batch_size, seq_len, embedding_dim).cuda()
expert_weights = [torch.randn(embedding_dim, 512).cuda() for _ in range(4)]
expert_biases = [torch.randn(512).cuda() for _ in range(4)]
moe_output = transformer_cpp.mixture_of_experts(input, expert_weights, expert_biases, k=2)
print("MoE Output Shape:", moe_output.shape)

# Test Transformer Model
num_heads_list = [8, 8]
num_experts_list = [4, 4]
expert_hidden_dim = 512
transformer_output = transformer_cpp.transformer_model(input, num_heads_list, num_experts_list, expert_hidden_dim)
print("Transformer Output Shape:", transformer_output.shape)

Attention Output Shape: torch.Size([2, 8, 256])


TypeError: mixture_of_experts(): incompatible function arguments. The following argument types are supported:
    1. (arg0: torch.Tensor, arg1: list[torch.Tensor], arg2: list[torch.Tensor], arg3: int) -> torch.Tensor

Invoked with: tensor([[[ 1.2716,  0.0791,  0.3014,  ..., -0.0559, -3.1170, -0.4878],
         [ 0.1177,  0.7039,  0.6225,  ..., -0.1891, -0.8746,  0.2392],
         [ 0.2799,  0.1586,  0.0927,  ..., -0.5317, -0.4224,  0.0466],
         ...,
         [-0.3291, -2.6983,  1.5693,  ...,  0.3031, -0.2768,  0.6406],
         [ 0.1456,  1.6887, -0.9113,  ..., -1.2413, -1.2076,  1.5084],
         [ 0.3808,  0.1884,  1.8818,  ...,  1.6507, -2.0407, -0.5481]],

        [[ 1.2101,  1.3958, -1.3185,  ...,  0.2368, -0.6657, -1.6289],
         [-1.2831,  1.3247,  0.8549,  ..., -0.3647, -0.4670,  0.9291],
         [-1.7751, -0.6742, -0.0553,  ...,  0.9023, -1.8265,  0.2921],
         ...,
         [ 0.4207, -0.3194, -0.3981,  ...,  1.8702,  1.8472,  0.5217],
         [-0.5587,  0.3636,  1.8041,  ...,  1.7110,  0.4525, -0.6995],
         [ 1.9115, -1.3403,  1.4462,  ..., -0.1928,  0.2030, -0.5873]]],
       device='cuda:0'), [tensor([[-1.1645, -0.9868,  1.9448,  ...,  0.6447,  0.3006, -2.4518],
        [-0.0669, -0.3269, -0.1815,  ..., -0.7026, -0.9414, -0.6791],
        [ 1.4593, -0.3270,  0.2311,  ...,  1.2688, -0.0443, -1.6987],
        ...,
        [-0.0562, -1.1487, -0.8611,  ...,  0.4111, -0.5700, -2.2571],
        [ 0.3267, -0.8375,  0.4823,  ..., -0.0710,  2.3593, -0.3627],
        [ 0.2861, -2.0299,  0.1663,  ..., -0.5921,  1.2068, -0.3433]],
       device='cuda:0'), tensor([[-0.6387, -0.7235, -0.1626,  ..., -0.7400, -1.8495,  0.2235],
        [ 0.5579,  1.2547,  0.3822,  ...,  3.3156, -0.2934, -0.8670],
        [ 2.0932,  1.9412,  1.1805,  ...,  0.6811,  1.7616, -0.4403],
        ...,
        [ 1.9042,  0.2068, -0.0833,  ...,  0.9077, -0.3426,  0.2985],
        [ 1.6584, -0.2692, -0.0209,  ..., -1.1988,  0.3827,  1.0190],
        [ 0.4417,  0.1634, -2.7271,  ...,  1.5670, -0.7958,  1.3405]],
       device='cuda:0'), tensor([[-0.3983, -0.7925, -0.4574,  ...,  0.0552,  0.2188,  0.3971],
        [-1.2899, -1.2663,  0.3558,  ...,  0.3936,  0.2137,  1.1953],
        [-0.8238, -0.1220,  0.0194,  ..., -1.6054,  0.7061,  1.2862],
        ...,
        [ 0.5698,  0.4595, -0.4823,  ...,  1.3871,  0.2107, -1.1829],
        [ 0.7284, -0.9407, -1.0973,  ..., -0.0498, -0.3411, -0.5730],
        [ 1.0840,  0.0986,  2.2815,  ..., -0.8109, -0.3137,  0.4051]],
       device='cuda:0'), tensor([[ 0.5646, -0.3069,  0.8318,  ..., -0.7398, -0.6328,  1.0514],
        [-0.6261,  0.6954, -1.8092,  ...,  0.2934,  0.4148, -0.2067],
        [-1.1490,  0.0681,  0.6838,  ..., -1.0542, -1.7663,  0.8216],
        ...,
        [ 0.2163,  1.2203,  0.7757,  ..., -0.8146, -0.3637,  0.2147],
        [-0.2988,  1.7238,  0.3818,  ...,  2.0650, -0.1951, -1.8099],
        [ 1.5450, -3.4767, -0.5990,  ..., -0.1125,  1.0594,  1.3566]],
       device='cuda:0')], [tensor([ 2.1314e+00,  9.8048e-01, -1.9978e+00,  8.4457e-01, -1.4488e+00,
         2.9103e-01, -1.5800e-01,  1.9601e+00, -2.8020e+00,  1.2415e+00,
        -1.3430e+00,  9.2835e-02, -1.1860e+00, -1.4540e+00, -9.9721e-01,
        -4.7136e-01,  4.6518e-01, -5.7338e-01, -2.2862e+00, -1.1351e-01,
         3.2860e-01,  1.0175e-02, -8.1090e-01, -1.2536e+00, -2.1338e-01,
        -2.4746e-01, -5.3608e-01,  1.1566e+00,  1.0427e+00, -3.2386e+00,
        -2.3381e-01,  1.0605e+00,  5.7956e-01, -2.7741e+00,  3.9278e-01,
         3.2381e-01,  6.0492e-01, -2.6745e-02,  4.7779e-01, -3.6416e-01,
         2.9884e-01,  3.2342e-01,  5.0958e-01,  1.4125e+00, -3.1193e-01,
        -4.4120e-01,  5.1658e-01, -3.5486e-01,  8.4731e-02,  1.0429e+00,
         2.7799e-01, -1.0768e+00, -5.8359e-01,  9.5306e-01,  1.5716e+00,
        -7.7447e-03,  1.0660e+00,  3.4591e-01, -4.1339e-01, -4.7356e-01,
         2.7366e-01,  1.1373e+00, -1.5685e+00,  4.6172e-01,  8.7261e-01,
        -9.9073e-01,  1.7604e+00, -7.7312e-01,  9.2283e-01,  5.6286e-01,
         2.4393e-01, -5.4526e-01,  7.0972e-01,  5.4975e-03,  2.0546e-01,
         6.2925e-01, -7.9299e-01,  4.0092e-01,  5.4875e-01, -3.3249e-01,
         5.9599e-02,  4.6815e-01, -1.0765e+00,  7.4393e-01, -1.2519e-01,
         2.3668e+00, -3.1369e-01,  8.0113e-01, -1.0222e+00, -2.6818e+00,
         5.2535e-02,  3.7121e-01,  4.7436e-02, -1.6034e+00,  1.0280e+00,
         1.1530e+00,  6.2829e-01,  7.7989e-01, -6.8269e-01,  1.7151e+00,
         2.3000e+00, -9.2818e-01,  1.2112e-01, -4.5644e-01,  1.1977e-01,
        -9.5165e-01,  3.3560e-01,  1.5429e-01, -1.3525e-01,  5.6980e-01,
         8.4226e-01,  6.7281e-02, -7.3456e-01, -1.4118e+00,  1.1672e+00,
        -2.3096e-01, -1.0627e-01,  2.3102e-03, -9.5416e-01, -1.0252e+00,
        -9.2760e-01,  7.3850e-01,  1.1787e+00,  8.6740e-01, -1.8952e+00,
         3.8093e-01, -1.8061e-02,  1.1030e+00,  2.4809e+00, -3.7159e-01,
        -2.0118e+00, -1.1864e+00,  1.2147e+00, -1.7257e+00,  1.2789e+00,
         1.3565e+00, -1.1379e+00,  1.0364e+00,  3.1803e-01,  1.8804e+00,
        -1.9007e+00, -9.7327e-01, -7.3145e-02, -1.1518e+00, -8.6983e-03,
         1.0150e-01, -7.7754e-01, -9.4932e-01, -1.2353e+00, -1.0995e-01,
         1.9386e-01,  1.7199e+00,  7.6792e-01, -9.1592e-01, -2.4664e-01,
        -4.0888e-01,  4.1677e-01,  1.4087e+00, -6.6212e-01, -3.7233e-01,
         1.3834e+00, -1.2668e+00,  3.3217e-01, -9.4202e-01, -2.0409e-01,
        -4.8256e-01, -4.4195e-01,  2.5578e-01,  1.0128e+00, -1.2120e+00,
        -8.1857e-01, -1.4591e-01, -2.1581e+00,  2.4048e+00, -4.4694e-02,
         1.6392e-01,  7.7301e-01,  5.8033e-01, -1.5249e+00,  2.8593e+00,
         7.6720e-01,  1.3316e-01,  8.4350e-01, -6.2354e-01,  1.3549e+00,
        -1.9293e+00, -3.0371e-01,  2.1342e-01, -9.4439e-01, -2.9927e-01,
        -2.7190e-01,  1.8947e+00, -5.6138e-01, -1.3177e-01, -8.5121e-01,
         3.0948e-01,  5.7887e-01, -6.2687e-01, -3.2177e-01, -1.1647e+00,
        -8.2643e-01, -1.4149e+00,  9.4821e-01, -1.3676e-01,  2.0907e+00,
         8.3701e-01,  2.8342e+00, -2.4759e-01,  9.4325e-02, -4.9226e-01,
         8.0352e-01,  1.9435e-01,  2.4708e-01, -5.3832e-02,  1.9700e+00,
         2.5213e+00, -1.6603e+00, -8.5949e-02, -4.5418e-01,  7.8949e-01,
        -1.9539e+00,  3.1227e-01, -6.5860e-01,  1.1963e+00,  1.7287e+00,
         3.4045e-01,  1.3139e+00, -1.2686e+00, -4.4250e-01,  9.7393e-01,
        -1.6138e-02, -4.8616e-01, -4.8859e-01, -8.8965e-01,  1.6543e-01,
         4.9452e-01,  7.5812e-01, -5.3940e-01,  5.1943e-01, -1.0053e+00,
         2.8981e+00,  1.7401e-01, -1.1174e+00,  1.8751e+00, -3.3193e-01,
        -5.1406e-01, -8.6872e-02, -2.7300e-01, -7.9391e-01,  1.2769e-01,
         2.0589e-01,  6.5605e-01, -1.1305e-01, -1.7438e+00,  1.0319e+00,
        -2.8927e-01, -8.2411e-02, -1.5295e+00, -2.4999e-01,  6.2113e-01,
        -5.0257e-01, -1.4107e+00, -1.2195e+00, -1.5958e-01, -1.6522e+00,
         1.5103e-01, -3.8429e-01, -6.0575e-02,  1.4243e+00, -7.0424e-01,
        -5.4975e-01,  1.3520e+00, -1.2457e+00, -1.3203e+00,  3.1314e-01,
         1.9715e-01, -1.1040e+00, -1.2573e+00,  4.8694e-01,  9.1593e-01,
         2.3304e-01, -2.7306e-01,  5.4994e-01, -1.7427e+00, -1.0622e+00,
         6.1241e-01,  1.6422e+00,  3.0837e-01,  1.0961e+00,  2.4986e-01,
         1.5916e+00, -2.7938e-01,  1.9746e+00,  1.5292e-01, -1.4616e+00,
         1.3324e+00,  1.1668e+00, -1.2881e+00, -7.0098e-01,  5.5991e-01,
         2.2512e+00,  1.1015e+00, -5.8363e-01, -8.4254e-01,  2.4024e-02,
         2.5145e-01, -1.5024e-01,  2.6399e+00, -1.0090e+00, -2.7919e+00,
        -1.1092e+00,  1.0443e+00, -1.3391e+00,  1.1750e+00,  1.1468e-01,
        -1.1711e+00,  2.9722e+00,  2.4692e+00,  1.2236e+00,  7.6377e-01,
        -8.7554e-01,  7.1764e-01, -2.5666e+00,  1.1135e-01,  6.1549e-01,
        -2.3669e-01, -8.9722e-01, -6.3254e-01,  1.6662e-01,  5.8564e-01,
        -6.3279e-01, -9.7789e-01, -8.6514e-01,  1.1818e+00, -1.4691e+00,
         1.1533e+00,  3.9072e-01,  3.8741e-01, -3.0837e-01,  1.6297e+00,
        -1.7564e+00, -8.3587e-01,  8.4853e-01,  5.3675e-01, -1.3554e+00,
        -4.3652e-01,  1.6532e+00,  3.0116e-01, -2.3875e-01, -8.7757e-01,
        -4.1314e-01, -1.3721e+00,  1.2119e+00,  1.7879e+00,  2.5834e-01,
         4.7925e-01, -4.0219e-01,  9.0784e-01,  6.5405e-01,  1.5079e-02,
         2.3062e-01, -4.9747e-01,  4.6767e-01, -1.6406e-01, -5.8116e-01,
        -1.9405e+00,  1.3544e+00,  1.5719e+00, -8.3671e-02, -5.7923e-01,
         1.2245e+00, -1.6807e-01, -1.1539e+00,  2.5886e-01, -8.7154e-01,
        -1.8109e+00, -2.1699e-01,  4.8152e-01, -1.7371e+00, -1.5696e+00,
         5.7711e-01, -4.2903e-01,  9.7449e-01,  1.0631e-02,  5.5964e-01,
        -1.8918e-01, -1.4932e+00,  8.2952e-02,  3.4341e-01, -8.9248e-01,
         2.2347e-01, -4.7273e-01, -4.3147e-02,  2.9740e-01,  4.3178e-01,
         1.4677e+00,  3.0317e+00,  1.4971e+00,  7.5494e-01,  9.2214e-01,
         1.7857e+00, -8.7640e-01, -7.4991e-01, -1.3095e+00,  1.2396e+00,
         1.5165e+00, -1.3747e-01, -1.5039e+00,  9.8377e-01,  8.9170e-01,
        -3.1595e-01, -6.8174e-01,  5.2977e-01,  3.0410e-01, -7.6748e-02,
        -2.0798e+00, -3.7951e-01, -2.0812e-01, -2.4830e-01,  8.3938e-01,
         1.5063e+00,  1.1416e-01,  1.4677e+00, -1.5824e+00, -8.1742e-01,
         6.1123e-02,  1.0300e-01,  5.8305e-01,  3.2208e-01,  7.9215e-01,
         2.3527e+00,  3.3290e-01,  1.1048e+00,  6.8093e-01,  4.5606e-01,
        -2.2133e+00, -5.8659e-01,  2.4141e+00, -2.1235e-01, -6.7960e-01,
         7.2723e-01,  4.4894e-01,  3.1231e-01, -7.7391e-01, -3.6108e-01,
         6.8751e-01, -3.0216e-01, -1.4938e+00, -1.2024e+00, -1.1580e+00,
         6.5644e-01,  1.1952e+00,  1.6717e+00,  8.9605e-01,  1.1596e+00,
        -1.0875e+00, -1.0987e+00,  4.9703e-01, -1.8331e-01, -1.4844e-01,
         1.3019e+00, -4.4058e-01,  4.0067e-01, -1.7244e+00,  5.7690e-01,
        -1.4810e+00,  1.3245e+00, -2.1345e-01, -2.1490e-01, -2.9793e-01,
        -5.6564e-01,  9.0298e-01,  1.3806e-01, -8.8979e-02,  9.5105e-01,
        -1.1837e+00, -3.4236e-01, -9.8399e-01, -9.8322e-01, -5.3815e-01,
         2.8419e-02, -1.2717e+00,  1.7064e+00,  1.6217e+00, -1.2654e+00,
         5.4157e-02, -2.6470e+00, -1.4813e+00,  8.8589e-01, -7.4451e-01,
        -1.8804e-01,  4.2779e-01,  8.7643e-02, -1.6250e+00, -4.1927e-02,
        -7.5984e-01,  1.1659e-01, -4.5601e-01,  7.0561e-01,  1.8396e-01,
        -5.6386e-01,  8.9338e-01, -1.7566e+00,  6.9294e-01,  1.0991e+00,
        -1.0712e+00, -9.4509e-02,  5.1544e-01, -4.0663e-01,  1.0554e+00,
         2.6602e-01,  1.3370e+00], device='cuda:0'), tensor([-0.0699, -0.9345,  0.7950,  0.8601, -0.3315,  1.6901,  0.7024, -0.7724,
        -0.7583, -0.6228, -0.6766,  2.2739,  0.6831, -0.8968, -0.3911, -1.8756,
         0.4721, -0.3791,  1.5270, -1.4338,  1.5458,  0.0323,  0.1091,  0.3440,
        -1.4501, -1.8873,  1.5766,  0.8359,  0.4349, -0.2657,  0.0625, -0.4172,
         1.2879, -0.9388,  0.8542, -1.9581,  1.3596,  0.5933,  0.9576, -0.1531,
         0.2777, -2.2216,  1.7561, -0.6508,  0.5140, -2.8418, -0.9115,  0.2625,
         1.0904, -1.2333,  0.3058,  0.8168, -1.4038,  0.8580, -0.1239,  0.4964,
        -0.5806, -0.8882,  1.6778, -0.1601,  0.0289,  0.8980, -1.0123, -0.1206,
        -0.7102, -0.5786, -0.1524,  2.4190,  1.3246,  1.4140, -3.0109, -2.0634,
        -0.3027,  0.8158, -0.2092, -0.4603,  1.0300,  1.1300,  1.2685,  0.0079,
        -0.6938,  0.0904, -0.3365, -0.1766, -0.2427,  0.6055, -0.7178, -1.2181,
         0.1505, -0.0470,  1.1983, -0.2715,  0.0597,  1.0751,  2.1848, -1.4275,
        -0.0326,  0.5921,  0.4007,  0.5095, -0.3991,  0.1187,  2.0611, -2.6236,
        -0.8123,  1.7212, -0.2347,  0.1355, -0.2277,  1.0655, -0.3462,  0.4389,
         0.6781, -2.1641, -0.5717,  0.5243, -0.1592,  0.9059, -0.6091, -1.1318,
         0.8175,  1.4012,  0.9018, -0.1154,  1.1777, -0.0881,  0.6217, -1.7204,
         1.2351,  1.4072,  0.2830, -2.5573,  0.1071, -0.1553, -1.2336,  1.2116,
        -0.4629,  1.0112, -1.6393,  0.1816,  0.3251, -0.3926,  0.0951, -2.0852,
         0.4868,  0.4979,  1.2728,  0.6232,  0.5472, -0.9913, -1.1029, -0.6398,
         1.0435,  0.0559,  0.5227, -1.2039,  1.2969, -1.0831, -0.0397,  1.4149,
         0.8071, -0.1652, -0.4068, -1.0976, -0.8987,  0.3723,  0.0447,  0.7615,
         0.0732,  0.6205, -1.1743,  2.6020, -0.1271, -0.0052, -0.6671,  1.6124,
        -0.1430,  2.1989, -1.9521,  0.2691,  0.9896, -1.1368,  0.6074, -0.1743,
         0.5059,  1.8516, -0.9927,  0.2449, -0.0495,  0.4485,  0.2943,  1.1098,
        -1.3600,  1.1673,  0.5840,  0.1786,  1.1561, -0.1280, -0.7711,  1.0751,
         0.9259, -0.9612,  0.5178, -0.6046, -0.4498,  0.5809,  1.3995,  0.7144,
         0.6597, -0.7407,  1.9446, -0.2083, -0.8569,  1.2912,  0.4873, -0.6642,
        -1.4881, -0.3054, -0.1096, -0.4275, -0.9645,  0.9363,  1.1617, -0.6436,
        -0.0462, -0.9648,  1.1952, -0.7914, -0.7759, -0.9790, -0.1931,  1.4289,
         0.6038, -1.4021,  0.0157, -0.8005, -0.8393, -0.7713, -0.4186,  0.0942,
        -0.8679,  1.3694,  0.7403,  1.3712,  0.4910, -0.5931,  0.5429, -1.0725,
         0.3442,  0.6317, -0.9498,  0.1186,  0.0239, -1.7281, -1.9096, -0.0369,
        -1.5204,  0.6712, -0.7110, -1.9271,  0.0084,  0.1951,  0.0622,  0.5617,
         0.6017, -0.9752, -2.1599,  0.5793,  0.3797, -1.0667, -1.3228, -0.2481,
         0.2336,  0.8697,  0.2177, -0.9167,  0.4563,  0.5358, -0.6521, -2.4893,
         0.4850, -0.0952,  1.5845,  1.1684, -0.3520, -0.2375,  0.8533, -2.6513,
        -0.4400,  1.0101, -0.4580, -2.4688, -0.9230, -0.4087,  0.3138, -1.4201,
         0.3860,  0.1604, -1.2472, -1.8029,  0.3761, -1.3686,  0.1928, -0.6756,
         0.0738, -1.2852,  0.4881, -2.0173, -0.9888, -0.0861,  0.7659, -0.2439,
        -0.5942,  0.1297,  0.5123,  0.6363, -0.9319,  0.5770,  0.0301,  0.5222,
         1.2587,  0.6318, -0.0253,  1.5601,  1.6591, -0.9839,  0.0922, -0.2363,
         0.2105, -0.7974, -1.1948,  0.2275,  1.0029,  0.3446, -0.2720, -0.2498,
        -0.4425,  1.5385, -0.2255, -0.5378, -1.6497, -1.1120, -0.9447, -1.5615,
         0.2739,  0.2050, -0.0041,  1.6473,  0.4976, -0.0850, -0.2416, -0.7288,
        -1.4243, -1.1103, -0.6998, -0.3004,  0.2715, -1.2155,  1.2070, -0.1017,
        -0.7297, -0.6405, -0.6678,  0.3806, -0.4871, -0.7978,  0.5085,  0.1511,
         0.4489,  0.5990, -1.7008, -0.1141, -0.8733, -0.1456,  0.4220,  0.4765,
         1.6769,  2.0592,  0.5703, -0.1795,  0.7006,  0.1263,  0.4934, -1.0502,
        -1.6785, -0.0329,  0.5945, -0.6865,  0.5914,  0.0523, -0.8143, -0.7001,
        -0.5099, -0.4031, -2.1630,  0.4883, -0.0111,  1.6418, -0.9013,  1.0738,
         1.7338,  0.0540,  1.1000, -0.3209, -0.9817,  0.3850, -0.1401, -1.1952,
        -0.9354,  0.3270,  0.5946,  0.4581, -1.8081,  0.4573, -2.2100, -1.0841,
        -0.4503,  0.0401, -0.6986, -0.1001, -0.5585,  1.3008,  0.5824, -0.8576,
        -1.1016,  0.3188, -2.1731, -0.3895, -0.4621, -1.3600,  0.7989, -1.8426,
        -0.7449, -0.0374, -0.8976, -0.1621, -1.2891, -0.2370,  2.3841, -0.2434,
        -1.4657,  0.1413,  0.1159,  0.0140, -1.7804,  0.7773,  0.4508, -0.4531,
         1.1835, -0.1742, -0.9040, -2.1000, -1.7178,  1.4413,  0.8778,  0.4435,
        -0.3765,  0.4465, -0.4878, -0.3186,  1.2043, -0.3013,  0.7574,  0.2820,
         0.4442, -0.3433, -0.3808, -0.7802, -0.6060, -1.8374, -0.0133,  0.2346,
        -0.3707,  0.1967, -0.8866,  0.0417, -0.6688,  1.2153, -0.3754,  0.1536,
        -0.3040,  0.0488,  0.2865,  1.2829, -1.3549, -0.6379,  2.0531,  0.2008,
        -0.0149, -0.3349,  1.1862, -1.1148,  0.1595,  2.5876,  0.9832, -1.5014,
        -1.5227,  0.0212, -0.2347,  0.8335, -0.6467, -1.1074,  0.8113,  0.4979,
         0.2459, -0.7718, -0.4346, -0.5569,  0.5764, -1.8747, -0.9246,  0.7513],
       device='cuda:0'), tensor([ 1.1550e+00, -3.1568e-01, -1.0681e+00,  1.5937e+00,  2.5037e+00,
         4.4685e-01,  7.8901e-01, -1.7580e-01, -1.3781e+00,  1.3402e-01,
        -1.0050e+00, -1.9760e-01, -1.0061e+00, -2.0665e+00,  5.3659e-01,
         4.8254e-01,  7.4519e-01, -2.1283e+00, -6.7953e-01, -1.5969e+00,
         8.4656e-01, -1.1199e+00,  1.6691e-01, -1.9056e-01, -3.7388e-02,
        -1.7429e-01,  8.2975e-01,  1.0044e+00,  2.0091e-01, -6.1632e-03,
         6.1120e-01, -7.6903e-01,  5.4002e-01, -9.6937e-01,  8.7030e-01,
        -1.7734e+00,  6.8457e-01, -1.2797e+00,  8.0323e-03, -4.6556e-01,
        -2.5287e-01,  4.3772e-01, -8.6073e-01, -2.9283e-01,  1.0062e+00,
         1.0785e+00, -5.0140e-01,  2.5094e-01,  6.5774e-01,  4.0023e-01,
        -5.4432e-01,  3.0572e-01, -1.3006e+00, -3.0700e-01,  9.2446e-01,
        -7.0759e-01, -2.1546e+00,  4.1818e-01,  4.7681e-01, -1.1844e+00,
        -2.8960e-02, -2.4670e-01, -3.7573e-01, -1.2009e-01,  6.6210e-02,
         3.3480e-01, -4.1880e-01,  1.6180e-01,  7.7261e-01, -1.6968e+00,
        -8.3982e-01, -7.5780e-01,  1.7200e+00, -5.6223e-01,  8.5973e-02,
         1.1104e+00, -2.6183e-01, -2.1574e-02, -7.4266e-01,  1.9469e-01,
         3.5672e-01,  5.7290e-01,  6.3094e-02, -1.0689e+00, -9.9635e-01,
        -1.6735e+00,  5.2042e-01, -1.9913e+00,  1.6228e-01,  7.7375e-01,
        -4.1587e-01,  6.5839e-01, -1.9424e-01,  6.1161e-02, -1.0451e-01,
        -8.2465e-01, -1.4780e+00, -5.3572e-01, -2.3412e-01,  5.1447e-01,
         1.0385e+00, -4.3098e-01,  1.1185e-02,  5.1732e-01,  2.1632e+00,
        -1.7020e+00,  4.2388e-01, -8.4792e-01,  6.5251e-01, -3.8599e-01,
        -3.6007e-01,  2.4882e-01,  1.8742e-01, -1.1384e+00,  1.6621e-01,
         9.7481e-01,  2.6380e+00, -6.8583e-01,  3.0980e-01, -3.4337e-01,
        -6.4537e-01,  7.7396e-02,  4.1175e-01, -1.6681e+00, -6.3440e-01,
         1.3185e-01, -1.1120e+00, -1.4456e+00, -4.3921e-01, -1.2204e+00,
         2.0129e+00,  9.8940e-01,  1.1071e+00, -2.7050e-02,  2.0693e+00,
         5.6667e-01, -1.8604e-01,  2.0796e-01, -2.7443e-01, -1.5496e+00,
        -4.9670e-01, -4.8706e-01, -1.6585e-01, -1.1878e+00,  1.6439e-01,
         2.0101e+00, -1.5759e+00, -1.7746e+00, -2.4706e+00,  8.8504e-01,
        -2.1732e-01, -6.4840e-01, -7.4837e-01,  1.4037e+00,  1.1812e+00,
        -3.5393e-01, -7.4942e-01,  7.0084e-01,  1.1506e+00,  2.2914e+00,
         3.2021e-02,  1.1164e+00,  8.6683e-01, -1.4031e-01,  5.0170e-01,
        -6.9221e-01,  7.5841e-01,  1.4134e+00,  1.4388e+00, -8.9351e-01,
         1.6069e+00, -1.8043e+00,  1.0421e+00, -2.2974e+00, -3.6776e-01,
         9.1429e-01, -4.8772e-02, -9.4127e-01, -3.5403e-01,  3.5928e-01,
         1.3678e+00, -5.7783e-01, -5.9616e-02,  2.1421e-01, -5.7352e-01,
         5.4328e-01,  1.0427e+00, -1.1309e+00,  9.4491e-01,  6.8133e-02,
        -1.7441e+00, -9.8989e-02, -9.2220e-01, -2.4651e+00,  1.1782e+00,
        -1.8370e+00,  4.3324e-01, -1.7124e-02, -7.6043e-02,  1.7841e+00,
        -2.4325e-01, -1.3105e+00,  2.1697e-01, -3.7794e-01,  1.6748e-01,
        -1.4395e+00,  1.1865e+00,  3.1526e-01, -2.4009e+00,  2.5103e-01,
         2.4584e-01,  3.1057e-01, -1.1921e+00, -1.2367e+00, -1.6780e+00,
        -5.6145e-01, -1.8833e-01,  2.0979e-01, -1.6315e+00,  1.5579e+00,
        -7.8442e-01,  2.3308e-01, -9.2106e-01, -9.3128e-01, -3.0092e-01,
         3.4052e-01,  1.2578e+00, -1.6218e+00, -2.0355e+00,  2.3375e+00,
        -1.7537e+00,  2.1962e+00, -7.9944e-01, -5.1961e-01,  2.1966e-02,
         6.2448e-01,  4.7352e-01,  1.2888e+00, -9.1179e-01, -6.3337e-01,
        -4.9299e-01, -4.1595e-01, -4.9596e-01, -7.0438e-02, -8.2897e-01,
        -6.2669e-01, -4.3211e-01, -1.0157e+00, -3.6239e-01,  7.1653e-01,
        -7.7576e-01,  4.7872e-02,  2.8539e-01, -1.5624e+00,  8.9652e-01,
        -1.2987e+00, -1.7620e-01,  5.0956e-01, -2.5484e+00,  6.9639e-01,
        -1.3313e+00,  1.7731e+00,  8.3339e-02, -4.7254e-01,  9.2492e-01,
        -1.1937e+00,  2.3216e-01,  6.2638e-01, -2.9821e-01, -9.5175e-01,
         1.3986e-01, -7.2828e-01,  9.3218e-01, -2.1498e-01,  1.9852e+00,
        -7.6348e-01, -1.0986e-01, -4.5050e-01,  8.9325e-01, -5.4250e-02,
        -2.2092e+00,  9.0535e-01, -4.7572e-01,  3.8196e+00,  5.3798e-01,
         6.7259e-01,  3.4446e-02,  9.9895e-01,  2.5469e-01,  5.2846e-01,
        -2.1099e-01,  6.2637e-01,  8.3245e-01, -5.4723e-01,  5.1626e-01,
        -2.8732e-01, -1.4760e+00, -6.0248e-01, -9.1217e-01, -1.0048e+00,
        -9.6727e-01,  6.8682e-01,  1.2641e+00,  1.1713e+00,  6.8342e-01,
        -1.1643e+00, -8.4251e-01,  1.0061e+00,  1.1720e+00,  8.1428e-01,
        -3.9032e-01, -1.6911e-01,  5.8808e-01, -7.8076e-01, -2.4616e-01,
        -1.6849e+00,  1.1659e+00, -1.8251e+00,  3.1699e-01,  1.5095e+00,
        -1.5060e+00,  3.6857e-01, -7.0881e-01,  5.0991e-02,  6.7777e-01,
         1.5028e+00,  2.6916e+00,  9.8157e-01,  1.1522e+00, -4.4316e-01,
        -2.2815e-01, -5.1983e-02, -1.6893e+00, -1.3268e+00,  1.1871e-01,
         1.3447e-01, -1.2886e+00, -2.1801e-01, -4.2458e-01, -3.3087e-01,
        -2.0921e-01,  3.2574e-01, -9.2461e-01,  2.6054e-02,  5.4284e-01,
        -1.2367e+00, -4.8402e-01, -2.6234e-01,  1.2544e+00, -1.1988e+00,
         8.6265e-01, -2.4934e+00,  3.2875e-01,  9.2096e-01,  2.0519e+00,
         5.9669e-01,  1.7006e+00,  1.2043e+00, -1.3831e-02,  8.2137e-01,
         2.5905e+00,  1.0919e+00,  6.3009e-01,  4.0482e-01, -1.9177e+00,
        -1.2173e+00, -7.8892e-01,  1.2347e-01,  4.9995e-01,  5.5700e-01,
        -1.2564e-01, -8.1435e-01, -5.6585e-01, -7.2065e-01, -2.9403e-02,
         5.6939e-01, -7.3482e-01,  4.1273e-01,  4.4252e-01,  3.6985e-02,
         4.5531e-01,  2.0893e-01, -7.5242e-01, -3.6504e-01, -1.3869e+00,
        -2.0901e-01,  2.7646e-01, -1.0748e+00,  9.7693e-01, -1.6419e+00,
         3.6965e-01, -1.7771e+00, -2.0028e+00,  5.6849e-01, -1.2572e+00,
        -9.8183e-01,  3.2700e-01, -1.1621e+00,  2.2477e+00,  2.2870e-01,
        -1.0811e+00,  1.7489e-01, -3.3924e-01, -1.1382e+00, -1.1581e+00,
        -4.9985e-01,  1.3037e+00,  8.4171e-01,  5.7023e-01, -1.0885e-01,
         9.9060e-01,  2.0441e-01, -1.3391e+00,  1.6382e-01, -7.7745e-01,
        -9.0423e-01, -1.5546e+00, -1.2986e-01,  2.8021e-01,  1.6554e+00,
        -8.1305e-01,  1.8922e-01,  7.8384e-01,  1.7426e-01, -9.5294e-01,
         9.4861e-01, -9.0602e-01, -1.3652e+00, -4.4656e-01,  2.2367e+00,
         2.0338e-01, -1.8825e+00, -1.7214e-01,  7.1797e-01,  9.0617e-01,
         1.8024e-01, -1.4715e-01, -9.1403e-01,  7.8615e-01,  4.2316e-01,
        -1.3731e+00,  2.6618e+00, -1.2546e+00, -7.6408e-01, -1.6649e-03,
         6.6941e-01,  9.2180e-01, -3.9468e-02, -8.3269e-01,  1.2527e+00,
         1.3417e+00, -3.0532e-01, -1.6798e+00,  1.1993e-01, -2.0846e+00,
        -1.4237e+00,  7.1825e-01,  1.4003e+00, -4.1155e-01, -1.3920e+00,
         8.8075e-01,  2.5274e-01, -6.7300e-01,  1.4368e+00, -1.0737e+00,
        -4.8454e-01,  1.0379e+00, -5.8462e-01, -1.1758e+00,  2.2032e-01,
        -6.0122e-01, -4.3992e-01,  1.7468e+00,  5.2493e-01, -1.9564e-02,
        -7.1687e-01,  8.6516e-01, -5.2663e-01,  5.6483e-02,  2.7542e-01,
         2.7912e-01, -2.6495e-02, -9.5285e-01, -1.3864e+00,  9.9803e-01,
        -8.2848e-01,  8.6416e-01,  1.2583e+00, -3.6157e-01, -1.8658e-01,
         5.3158e-01,  3.3203e-01,  6.7315e-02,  1.0896e-01,  2.7936e-01,
         9.0675e-01, -8.2908e-01,  3.1983e-01,  9.0398e-01, -4.6311e-01,
         3.5315e-01, -3.1564e-01, -1.0450e+00,  8.6836e-01, -8.8960e-01,
        -1.3014e+00, -8.3309e-01, -3.5340e-02, -1.6836e+00, -3.5367e-01,
        -3.7012e-01,  9.9008e-01], device='cuda:0'), tensor([-6.6767e-01, -5.7646e-02,  1.3565e+00, -5.0354e-01, -6.6323e-01,
         4.2017e-01, -2.4295e-01, -1.0157e-01,  4.6986e-01, -6.1501e-01,
         1.1567e+00, -3.0075e-01, -8.5670e-02,  5.2819e-01, -6.5581e-01,
         1.0505e+00, -3.3438e-01, -1.3841e+00, -8.2046e-01, -9.0455e-01,
         1.0299e+00, -7.0831e-01,  2.0814e-01,  1.0798e+00, -8.4214e-02,
         8.0716e-01,  9.0439e-02, -1.0130e+00, -1.4893e-01,  1.7855e+00,
        -2.4391e+00, -1.0759e+00,  2.1657e+00,  5.2804e-01,  9.1572e-01,
        -3.5939e-02, -3.2567e-01,  1.8311e+00,  2.1704e+00,  1.6349e+00,
        -1.8078e-02, -1.0913e+00, -9.8594e-03, -9.8924e-01, -1.0591e+00,
        -5.6357e-01,  1.2688e+00, -5.5107e-03,  1.1176e+00, -1.3105e+00,
        -5.5822e-01,  1.0501e+00,  1.0404e+00, -1.1342e-01,  2.9354e-01,
         9.3600e-01, -9.0754e-01,  2.3386e-01, -3.7286e-01, -1.9036e+00,
         2.0367e-01, -3.9186e-01, -7.6839e-02, -2.4304e-01, -1.7011e+00,
         1.8848e-01,  1.0916e+00,  6.3384e-01,  2.2600e+00,  1.1707e+00,
         9.2363e-01,  3.1978e-01,  1.6158e+00,  1.1923e+00, -5.1646e-01,
         8.7766e-01,  8.2237e-01,  1.0471e+00, -2.5357e+00, -5.8972e-01,
        -2.2950e-01,  1.9660e+00,  1.1024e+00,  9.1609e-02,  2.1419e+00,
         8.1431e-01, -4.8656e-01, -7.5956e-02,  4.0068e-01, -6.0208e-01,
        -5.8016e-01, -8.3305e-01,  3.7636e-02, -1.0388e-01,  1.4589e+00,
         1.4573e+00, -1.2138e+00,  4.7063e-01, -1.6900e+00,  3.0377e-02,
         1.5244e+00, -2.7403e+00,  1.2126e+00,  4.4050e-01,  9.6501e-02,
         5.6082e-01, -1.3961e+00, -8.3201e-01,  3.6328e-02, -3.7352e-01,
         5.0482e-01,  7.8123e-01,  5.0126e-01, -1.0457e-01,  7.9323e-01,
        -1.8541e+00,  1.1928e+00,  6.1304e-01,  4.0197e-02, -8.0014e-01,
         1.4284e+00,  2.5383e+00, -1.4152e+00,  1.2843e+00,  3.8021e-01,
         4.4228e-01, -2.4221e-01,  3.1517e-01,  1.0087e+00, -1.1564e+00,
        -6.6309e-01, -1.3386e-01,  4.8712e-01, -1.0697e+00, -2.7909e-01,
        -1.2332e+00, -4.1780e-01, -1.0828e+00,  4.7444e-01, -9.5583e-01,
        -7.6926e-01, -4.0158e-02, -1.5411e+00, -9.7320e-02, -9.2787e-01,
         9.6384e-01,  7.5068e-01, -8.3493e-01, -8.8207e-01,  2.6231e-01,
         4.6214e-01,  1.3831e+00,  1.6701e+00, -1.0793e+00,  1.8534e+00,
        -7.6416e-01, -1.4705e+00,  1.2739e-01,  4.1090e-01, -5.4833e-01,
        -1.8407e+00,  2.0735e-01,  1.3397e+00, -1.5620e+00,  6.9707e-01,
         4.9436e-02, -1.7875e+00, -2.8346e-01, -8.9905e-01,  1.1056e+00,
         1.8029e-02,  3.3304e-01, -2.2753e+00,  4.3529e-01,  9.7364e-02,
         1.6514e-01, -1.0646e+00,  2.6381e-02, -1.2511e+00,  1.8966e-01,
        -1.0003e+00, -8.3720e-01,  1.3760e+00,  1.9247e-01,  5.4224e-01,
         2.0977e-01, -1.4022e+00, -1.1448e+00, -1.4053e+00,  9.2049e-01,
         2.1746e-01,  9.2170e-02, -4.0428e-01,  9.8942e-01, -7.0087e-01,
        -2.2281e-01,  1.2003e+00, -1.1565e+00, -7.9928e-01, -8.8279e-01,
        -1.6068e-01, -1.0650e+00, -3.1858e-01,  5.7948e-01,  1.2504e+00,
        -6.4016e-01,  2.1789e-01, -1.7705e+00,  2.7913e-01, -7.8323e-02,
         1.3323e+00, -1.5233e+00,  4.9594e-01, -2.0200e+00, -1.5383e+00,
         1.0946e+00,  4.4705e-01,  1.2157e+00, -2.2917e-01, -9.2577e-03,
        -9.7840e-01, -8.7997e-01,  6.9543e-01,  5.1217e-01, -3.7736e-01,
        -6.1710e-01, -5.3119e-01, -8.3445e-01,  1.3310e-01, -8.8080e-01,
        -2.0019e-02,  5.5784e-02,  4.1146e-02, -1.2130e+00,  3.9147e-01,
         4.3127e-01,  5.1102e-01,  1.7138e+00,  4.2608e-01,  1.7685e-02,
        -1.0015e-02,  3.6454e-01, -4.3323e-01,  9.7540e-01,  5.2498e-01,
         9.3650e-01,  8.6043e-01,  7.4651e-01,  3.1342e-01,  6.5766e-01,
         8.3368e-01,  1.7694e-01,  1.5484e-01, -4.2745e-01,  7.8894e-01,
         1.0377e+00, -9.8748e-01,  7.7732e-01,  1.9739e-02,  6.7186e-01,
         8.1707e-01, -1.9044e+00, -2.4544e-01, -7.6392e-02, -3.6193e-01,
        -3.8769e-01, -1.9584e+00, -1.1531e+00,  2.3126e+00, -9.2154e-01,
         1.7702e+00,  3.9169e-01, -1.3876e+00,  9.5252e-01, -1.4290e-01,
        -6.9848e-01, -5.2199e-01,  1.1329e+00,  1.3222e+00, -1.1200e+00,
        -9.9116e-01, -1.0120e+00, -2.5778e+00,  6.4410e-01, -3.8283e-01,
         2.1016e-01,  1.8963e-02, -9.7066e-01,  2.8785e-01,  4.7774e-01,
         4.8822e-01,  5.4780e-02, -6.5914e-01,  3.2605e-01, -1.0491e+00,
        -6.4101e-01,  2.3300e+00, -6.2596e-01, -1.8868e+00, -4.3973e-01,
        -8.8484e-01, -1.6658e-01,  6.4850e-01,  2.1093e+00, -1.4774e+00,
        -1.2087e+00,  6.3067e-01,  4.6883e-01, -5.3408e-01, -7.7926e-01,
         9.5824e-02, -7.6440e-01,  9.6985e-01, -2.5359e+00, -6.8982e-02,
         1.0044e+00,  2.4409e-02,  1.4610e+00, -1.7235e-01,  1.0690e+00,
         1.3793e-01,  1.8142e+00,  7.1423e-01, -2.2346e-01, -1.2009e+00,
         7.3666e-01, -3.7086e-01,  2.4558e-01,  2.5531e+00, -2.4000e-01,
        -5.4846e-02, -2.6816e-01,  3.7795e-01,  1.1430e-01, -4.6881e-01,
        -1.1657e+00, -4.8808e-01, -2.6580e-01,  1.7069e+00, -2.7865e-02,
        -3.7453e-01, -2.6121e-01, -1.3830e+00, -4.9255e-01, -2.2941e-01,
        -5.9620e-01,  8.1061e-01,  6.9998e-01,  1.2715e+00,  7.3079e-01,
        -5.6427e-01, -1.2220e+00,  3.4832e-02,  3.2086e-01, -5.4313e-01,
         1.7084e+00,  7.9149e-01,  6.5261e-01, -2.5826e-02, -2.2327e-01,
         2.3669e+00, -5.3408e-01, -1.5230e-01, -9.2334e-01,  1.3326e+00,
        -7.6369e-02,  1.2216e+00,  5.5947e-01,  4.1221e-01, -5.6956e-01,
         9.4459e-01,  7.7669e-01, -1.7908e+00,  1.0233e+00,  4.8211e-04,
        -2.3527e+00, -6.7877e-01, -1.5977e-01, -3.7327e-01, -6.1969e-01,
         8.5646e-02, -1.0169e+00,  2.7507e-01, -2.5113e-01, -4.6179e-01,
        -6.4008e-01, -7.7968e-01, -6.8040e-01, -2.7325e+00, -1.6874e+00,
         1.8234e-01, -1.6864e-01, -1.6708e-01,  1.4814e+00,  2.1145e-01,
         3.4090e-01,  5.2920e-02,  7.6082e-01,  8.7309e-01,  1.8432e-01,
        -1.3489e-01, -1.0940e+00, -3.1579e-01, -3.9033e-01,  2.0510e-01,
         7.4155e-01, -1.1795e+00,  6.3909e-01, -6.5035e-01, -4.3123e-02,
         2.0314e-01,  1.2447e+00,  4.1114e-01, -3.9587e-01, -2.9633e-01,
         1.1304e+00,  1.4137e+00, -9.8445e-01,  6.7248e-01,  4.4833e-01,
        -7.1536e-01, -8.8125e-01, -7.5365e-01, -7.5274e-01, -1.0520e+00,
         3.4893e-01, -5.7328e-01, -8.8626e-01, -1.6148e+00, -1.0020e+00,
         1.6210e+00,  2.4201e-01, -1.2357e+00, -9.6524e-02,  1.0691e-02,
         6.9678e-01, -6.5331e-01, -3.7209e-01, -2.1802e-01,  1.0739e+00,
        -8.0581e-01, -9.2653e-02, -1.3553e+00, -1.2789e+00, -1.7694e+00,
         2.0759e-01, -1.2419e+00, -1.2623e+00,  7.6426e-01,  1.0490e+00,
        -1.1550e-01, -4.6467e-01,  2.8713e-01,  5.3597e-01,  6.2920e-01,
        -6.6733e-01,  5.9183e-01,  5.9663e-01, -3.1872e-01, -1.2496e+00,
        -1.1879e+00,  1.7444e+00,  1.3125e+00, -9.3437e-01,  2.2411e+00,
        -4.1821e-01, -1.2039e+00, -1.9324e+00,  3.2658e-01, -3.2100e-01,
         1.6672e-01, -4.3496e-01,  2.7492e-01,  1.8384e-01,  9.4644e-01,
        -1.0046e+00, -5.4279e-02,  8.0669e-01,  9.2378e-01, -1.6853e+00,
        -9.0038e-01, -1.4592e-02, -6.4680e-01, -9.2415e-01,  2.1867e+00,
        -9.8426e-01, -1.1788e+00, -1.1580e+00, -1.1369e+00, -4.0604e-01,
         6.0829e-01, -1.6881e-01, -6.5919e-01, -3.1495e-01,  1.2155e+00,
         1.1048e+00,  1.0094e+00, -5.0273e-01, -4.3230e-02, -7.8570e-01,
        -1.8852e-01,  2.1055e-01,  2.4077e-01, -1.1684e+00,  8.6620e-01,
        -1.4574e-01, -1.4324e+00, -1.9516e+00,  3.7096e-01, -5.9363e-01,
         2.2167e-01, -1.2148e+00], device='cuda:0')]; kwargs: k=2