In [1]:
import torch
from tqdm import trange
import numpy as np
import scipy
import sys
torch.set_grad_enabled(False)
import torch.nn as nn
import math
import time 

### Recursive decomposition

In [2]:
torch.manual_seed(0)

def fft_matrix(N):
    n = torch.arange(N)
    k = n.view(-1, 1)
    M = torch.exp(-2j * torch.pi * n * k / N)
    return M

def compute_twiddle_factors_fft(n, m):
    """Compute the twiddle factors of size n x m"""
    n_a = torch.arange(n).view(-1, 1)
    m_a = torch.arange(m)
    N = n * m
    M = torch.exp(-2j * torch.pi * n_a * m_a / N)
    return M

def ifft_matrix(N):
    n = torch.arange(N)
    k = n.view(-1, 1)
    M = torch.exp(2j * torch.pi * n * k / N)
    return M

def compute_twiddle_factors_ifft(n, m):
    """Compute the twiddle factors of size n x m"""
    n_a = torch.arange(n).view(-1, 1)
    m_a = torch.arange(m)
    N = n * m
    M = torch.exp(2j * torch.pi * n_a * m_a / N)
    return M

def monarch_conv(x, k_f, f_sqrt_N_fft, twiddle_factors_fft, f_sqrt_N_ifft, twiddle_factors_ifft, N, sqrt_N):
    '''
    x: (B, H, N)
    k_f: (H, N)
    f_sqrt_N_fft: (sqrt_N, sqrt_N)
    twiddle_factors_fft: (sqrt_N, sqrt_N)
    f_sqrt_N_ifft: (sqrt_N, sqrt_N)
    twiddle_factors_ifft: (sqrt_N, sqrt_N)
    N: 16K
    sqrt_N: 32
    '''
    B, H, _, N = x.shape

    # compute the FFT
    x = x.reshape(B, H, _, sqrt_N, sqrt_N)
    x = x.transpose(-1, -2)
    x = x @ f_sqrt_N_fft
    x = x.transpose(-1, -2)
    x = x * twiddle_factors_fft # (H, sqrt_N, sqrt_N) * (sqrt_N, sqrt_N), pointwise
    x = x @ f_sqrt_N_fft
    # x = x.transpose(-1, -2)

    print(f"\nIn the inner function:")
    print(f"{x.shape=}; {twiddle_factors_fft.shape=}")
    print(f"{f_sqrt_N_fft.shape=}; {twiddle_factors_fft.shape=}")
    print(f"End of inner function\n")

    # pointwise multiplication 
    k_f = k_f.reshape(H, _, sqrt_N, sqrt_N) # to match the shape of x
    x = x * k_f

    # compute the IFFT
    # x = x.transpose(-1, -2)
    x = x @ f_sqrt_N_ifft
    x = x.transpose(-1, -2)
    x = x * twiddle_factors_ifft # (H, sqrt_N, sqrt_N) * (sqrt_N, sqrt_N), pointwise
    x = x @ f_sqrt_N_ifft
    x = x.transpose(-1, -2) # necessary to complete the ifft

    x = x.reshape(B, H, _, N)
    return x


def monarch_conv_for_loop(x, k_f, f_sqrt_N_fft, twiddle_factors_fft, f_sqrt_N_ifft, twiddle_factors_ifft, N, sqrt_N):
    B, H, _, N = x.shape

    print(x.shape)
    x = x.reshape(B, H, _, sqrt_N, sqrt_N)
    k_f = k_f.reshape(H, _, sqrt_N, sqrt_N)

    x2 = x.clone()
    chunk_size = 4  # Example chunk size (adjust based on requirements)
    chunks = x.shape[2] // chunk_size
    print(f"{chunks=}")
    for i in range(chunks):
        # Take a block of 'chunk_size' slices at a time
        block = x[:, :, i*chunk_size:(i+1)*chunk_size].transpose(-1, -2)
        block = block @ f_sqrt_N_fft  # Apply FFT
        block = block.transpose(-1, -2)
        block = block * twiddle_factors_fft
        block = block @ f_sqrt_N_fft  # Apply FFT again
        x2[:, :, i*chunk_size:(i+1)*chunk_size] = block

        # pointwise multiplication
        x2[:, :, i*chunk_size:(i+1)*chunk_size] = x2[:, :, i*chunk_size:(i+1)*chunk_size] * k_f[:, i*chunk_size:(i+1)*chunk_size]

        # Now apply the IFFT
        block = x2[:, :, i*chunk_size:(i+1)*chunk_size]
        block = block @ f_sqrt_N_ifft  # Apply iFFT
        block = block.transpose(-1, -2)
        block = block * twiddle_factors_ifft  # Element-wise multiplication with iFFT twiddle factors
        block = block @ f_sqrt_N_ifft  # Apply second iFFT
        block = block.transpose(-1, -2)  # Final transpose to restore shape
        x2[:, :, i*chunk_size:(i+1)*chunk_size] = block

    x2 = x2.reshape(B, H, _, N)
    return x2
    

def monarch_conv_full(
    x, k_f, 
    f_32_fft, 
    twiddle_factors_fft, twiddle_factors_32_fft,
    f_32_ifft, 
    twiddle_factors_ifft, twiddle_factors_32_ifft,
    N, sqrt_N_32, N_1024):
    '''
    x: (B, H, N)
    k_f: (H, N)
    f_32_fft: (32, 32)
    twiddle_factors_fft: (32, 1024)
    f_32_ifft: (32, 32)
    twiddle_factors_ifft: (32, 1024)
    N: 32768
    '''
    B, H, N = x.shape

    # compute the FFT
    x = x.reshape(B, H, sqrt_N_32, N_1024)
    x = x.transpose(-1, -2) # ... 1K, 16
    x = x @ f_32_fft    # ... 1K, 16
    x = x.transpose(-1, -2) # ... 16, 1K
    x = x * twiddle_factors_fft # (H, sqrt_N, 1K) * (16, 1K), pointwise

    x = monarch_conv_for_loop(
        x, k_f, f_32_fft, twiddle_factors_32_fft,
        f_32_ifft, twiddle_factors_32_ifft, N_1024, 32)

    x = x * twiddle_factors_ifft # ... 16, 1K
    x = x.transpose(-1, -2) # ... 1K, 16
    x = x @ f_32_ifft   # ... 1K, 16
    x = x.transpose(-1, -2) # ... 16, 1K

    x = x.reshape(B, H, N)
    return x

def ref_conv(x, k_f):
    x = torch.fft.fft(x)
    x = x * k_f
    x = torch.fft.ifft(x)
    return x


In [3]:
# inputs to the code
B = 16
H = 768
N = 32 * 32 * 32
sqrt_N_1 = 32
sqrt_N_2 = 32
device = 'cuda' # for laptop
x = torch.randn(B, H, N, device=device).to(torch.cfloat) * 0.2

k = torch.randn(H, N, device=device) * 0.02
k_f = torch.fft.fft(k).to(torch.cfloat)
f_32_fft = fft_matrix(sqrt_N_1).to(x.device)
f_32_ifft = ifft_matrix(sqrt_N_1).to(x.device)
twiddle_factors_fft_32_1K = compute_twiddle_factors_fft(sqrt_N_1, 1024).to(x.device)
twiddle_factors_fft_32_32 = compute_twiddle_factors_fft(sqrt_N_2, sqrt_N_2).to(x.device)
twiddle_factors_ifft_32_1K = compute_twiddle_factors_ifft(sqrt_N_1, 1024).to(x.device) / N
twiddle_factors_ifft_32_32 = compute_twiddle_factors_ifft(sqrt_N_2, sqrt_N_2).to(x.device)
k_f_permuted = k_f.reshape(H, 1024, sqrt_N_1).transpose(-1, -2).reshape(H, sqrt_N_1, sqrt_N_2, sqrt_N_2).transpose(-1, -2).reshape(H, N).contiguous()

OutOfMemoryError: CUDA out of memory. Tried to allocate 1.50 GiB. GPU 

In [5]:
out = monarch_conv_full(
    x, k_f_permuted, 
    f_32_fft,  
    twiddle_factors_fft_32_1K, twiddle_factors_fft_32_32,
    f_32_ifft, 
    twiddle_factors_ifft_32_1K, twiddle_factors_ifft_32_32,
    N, 32, 1024
)
out = out.real

out_ref = ref_conv(x, k_f)
out_ref = out_ref.real
print(torch.max(torch.abs(out - out_ref)))

torch.Size([16, 768, 32, 1024])
chunks=8
tensor(3.4034e-05, device='cuda:0')


In [6]:
# reference
out_ref = ref_conv(x, k_f)
out_ref = out_ref.real

# ours
B, H, N = x.shape
sqrt_N_32 = 32
N_1024 = 1024

x = x.reshape(B, H, sqrt_N_32, N_1024)
x = x.transpose(-1, -2) # ... 1K, 16
x = x @ f_32_fft    # ... 1K, 16
x = x.transpose(-1, -2) # ... 16, 1K
print(f"{x.shape=}, {twiddle_factors_ifft_32_1K.shape=}")
x = x * twiddle_factors_fft_32_1K # (H, sqrt_N, 1K) * (16, 1K), pointwise
# x = x @ f_256_fft       # ... 16, 1K

x = monarch_conv(
        x, k_f_permuted, f_32_fft, twiddle_factors_fft_32_32,
        f_32_ifft, twiddle_factors_ifft_32_32, N_1024, 32)

x = x * twiddle_factors_ifft_32_1K # ... 16, 1K
x = x.transpose(-1, -2) # ... 1K, 16
x = x @ f_32_ifft   # ... 1K, 16
x = x.transpose(-1, -2) # ... 16, 1K
x = x.reshape(B, H, N)
x = x.real

print(torch.max(torch.abs(x - out_ref)))

x.shape=torch.Size([16, 768, 32, 1024]), twiddle_factors_ifft_32_1K.shape=torch.Size([32, 1024])

In the inner function:
x.shape=torch.Size([16, 768, 32, 32, 32]); twiddle_factors_fft.shape=torch.Size([32, 32])
f_sqrt_N_fft.shape=torch.Size([32, 32]); twiddle_factors_fft.shape=torch.Size([32, 32])
End of inner function



tensor(3.4034e-05, device='cuda:0')


### Blocked PyTorch Code

In [7]:
f_32_fft

tensor([[ 1.0000+0.0000j,  1.0000+0.0000j,  1.0000+0.0000j,  ...,
          1.0000+0.0000j,  1.0000+0.0000j,  1.0000+0.0000j],
        [ 1.0000+0.0000j,  0.9808-0.1951j,  0.9239-0.3827j,  ...,
          0.8315+0.5556j,  0.9239+0.3827j,  0.9808+0.1951j],
        [ 1.0000+0.0000j,  0.9239-0.3827j,  0.7071-0.7071j,  ...,
          0.3827+0.9239j,  0.7071+0.7071j,  0.9239+0.3827j],
        ...,
        [ 1.0000+0.0000j,  0.8315+0.5556j,  0.3827+0.9239j,  ...,
         -0.1951-0.9808j,  0.3827-0.9239j,  0.8315-0.5556j],
        [ 1.0000+0.0000j,  0.9239+0.3827j,  0.7071+0.7071j,  ...,
          0.3827-0.9239j,  0.7071-0.7071j,  0.9239-0.3827j],
        [ 1.0000+0.0000j,  0.9808+0.1951j,  0.9239+0.3827j,  ...,
          0.8315-0.5556j,  0.9239-0.3827j,  0.9808-0.1951j]], device='cuda:0')

In [24]:
# inputs to the code
B = 16
H = 768
N = 32 * 32 * 32
sqrt_N_1 = 32
sqrt_N_2 = 32
device = 'cuda' # for laptop
x = torch.randn(B, H, N, device=device).to(torch.cfloat) * 0.2

k = torch.randn(H, N, device=device) * 0.02
k_f = torch.fft.fft(k).to(torch.cfloat)
f_32_fft = fft_matrix(sqrt_N_1).to(x.device)
f_32_ifft = ifft_matrix(sqrt_N_1).to(x.device)
twiddle_factors_fft_32_1K = compute_twiddle_factors_fft(sqrt_N_1, 1024).to(x.device)
twiddle_factors_fft_32_32 = compute_twiddle_factors_fft(sqrt_N_2, sqrt_N_2).to(x.device)
twiddle_factors_ifft_32_1K = compute_twiddle_factors_ifft(sqrt_N_1, 1024).to(x.device) / N
twiddle_factors_ifft_32_32 = compute_twiddle_factors_ifft(sqrt_N_2, sqrt_N_2).to(x.device)
k_f_permuted = k_f.reshape(H, 1024, sqrt_N_1).transpose(-1, -2).reshape(H, sqrt_N_1, sqrt_N_2, sqrt_N_2).transpose(-1, -2).reshape(H, N).contiguous()


x2 = x.clone()
x3 = x.clone()
base_case = ref_conv(x2, k_f)

In [25]:
# compute fft - stage 1
x = x.reshape(B, H, 32, 1024)   # ... 32, 1K
x = x.transpose(-1, -2) # ... 1K, 32
x = x @ f_32_fft    # ... 1K, 32
print(x[0,0,:,0])   
print(f"{x.shape=}")
x = x.transpose(-1, -2) # ... 32, 1K
x = x * twiddle_factors_fft_32_1K # (H, sqrt_N, 1K) * (16, 1K), pointwise
# store

# compute fft - stage 2
x = x.reshape(B, H, 32, 32, 32)
x = x.transpose(-1, -2)
x = x @ f_32_fft
x = x.transpose(-1, -2)
x = x * twiddle_factors_fft_32_32 # (H, sqrt_N, sqrt_N) * (sqrt_N, sqrt_N), pointwise
x = x @ f_32_fft

# # pointwise multiplication 
k_f = k_f_permuted.reshape(H, 32, 32, 32) # to match the shape of x
x = x * k_f

# compute ifft -- stage 2
x = x @ f_32_ifft
x = x.transpose(-1, -2)
x = x * twiddle_factors_ifft_32_32 # (H, sqrt_N, sqrt_N) * (sqrt_N, sqrt_N), pointwise
x = x @ f_32_ifft
x = x.transpose(-1, -2) # necessary to complete the ifft

x = x.reshape(B, H, 32, 1024)

# compute ifft -- stage 1
x = x * twiddle_factors_ifft_32_1K # ... 16, 1K
x = x.transpose(-1, -2) # ... 1K, 16
x = x @ f_32_ifft   # ... 1K, 16
x = x.transpose(-1, -2) # ... 16, 1K
x = x.reshape(B, H, N)

print(torch.max(torch.abs(x.real - base_case.real)))

tensor([ 0.5190+0.j, -1.9657+0.j,  0.7759+0.j,  ...,  0.7705+0.j,  0.4249+0.j,
         0.1022+0.j], device='cuda:0')
x.shape=torch.Size([16, 768, 1024, 32])
tensor(3.0637e-05, device='cuda:0')


In [26]:
######### BLOCKED IMPLEMENTATION #########

x2 = x3.reshape(B, H, 32, 1024)
print(f"{k_f.shape=}")
chunk_size = 32 
chunks = 32 
for i in range(chunks):
    block = x2[:, :, :, i*chunk_size:(i+1)*chunk_size]
    block = block.transpose(-1, -2) 
    # block = block @ f_32_fft 
    block = block.transpose(-1, -2)
    block = block * twiddle_factors_fft_32_1K[:, i*chunk_size:(i+1)*chunk_size]
    print(block.real)
    x2[:, :, :, i*chunk_size:(i+1)*chunk_size] = block


x2 = x2.reshape(B, H, 32, 32, 32)
chunk_size = 4 
chunks = x2.shape[2] // chunk_size
print(f"{chunks=}; {x2.shape=}; {chunk_size=}")
for i in range(chunks):
    block = x2[:, :, i*chunk_size:(i+1)*chunk_size]
    block = block.transpose(-1, -2)
    block = block @ f_32_fft  # Apply FFT
    block = block.transpose(-1, -2)
    block = block * twiddle_factors_fft_32_32
    block = block @ f_32_fft  # Apply FFT again
    x2[:, :, i*chunk_size:(i+1)*chunk_size] = block

    # # pointwise multiplication
    x2[:, :, i*chunk_size:(i+1)*chunk_size] = x2[:, :, i*chunk_size:(i+1)*chunk_size] * k_f[:, i*chunk_size:(i+1)*chunk_size]

    # Now apply the IFFT
    block = x2[:, :, i*chunk_size:(i+1)*chunk_size]
    block = block @ f_32_ifft  # Apply iFFT
    block = block.transpose(-1, -2)
    block = block * twiddle_factors_ifft_32_32 
    block = block @ f_32_ifft
    block = block.transpose(-1, -2) 
    x2[:, :, i*chunk_size:(i+1)*chunk_size] = block
    
x2 = x2.reshape(B, H, 32, 1024)


# compute ifft
chunk_size = 256 
chunks = 1024 // chunk_size
for i in range(chunks):
    block = x2[:, :, :, i*chunk_size:(i+1)*chunk_size] 
    block = block * twiddle_factors_ifft_32_1K[:, i*chunk_size:(i+1)*chunk_size] 
    block = block.transpose(-1, -2) 
    block = block @ f_32_ifft  
    block = block.transpose(-1, -2)
    x2[:, :, :, i*chunk_size:(i+1)*chunk_size] = block

x2 = x2.reshape(B, H, N)
print(x.shape)
print(x2.shape)

print(x[0,0,:4].real)
print(x2[0,0,:4].real)
all_close = torch.allclose(x, x2, atol=1e-2)  
print(torch.max(torch.abs(x.real - x2.real))) 
print(all_close)
all_close = torch.allclose(x2, base_case, atol=1e-2)
print(all_close)

k_f.shape=torch.Size([768, 32, 32, 32])
tensor([[[[-1.6281e-01, -1.7063e-02, -1.0231e-01,  ...,  8.2185e-02,
            1.0062e-01, -7.8620e-02],
          [-3.3471e-03, -3.1587e-01,  3.0419e-02,  ..., -1.6278e-01,
            1.2327e-01,  3.4597e-01],
          [ 4.2181e-02, -2.4799e-01, -2.5065e-01,  ..., -1.3415e-01,
            1.7996e-01,  1.1810e-01],
          ...,
          [ 1.7995e-01,  3.1803e-01,  4.9548e-01,  ...,  2.9793e-02,
           -1.3749e-01, -4.8869e-01],
          [-1.1798e-01, -2.7385e-01,  3.1999e-02,  ...,  2.8935e-02,
           -4.6507e-02, -1.8730e-02],
          [ 3.8452e-03,  5.1120e-02, -1.4842e-02,  ..., -4.1960e-01,
           -2.4970e-01, -2.9385e-02]],

         [[-1.0266e-01, -6.7748e-03, -2.3231e-01,  ..., -1.3138e-01,
            5.6204e-02,  1.6878e-02],
          [-1.5261e-01, -1.3247e-01,  4.5678e-02,  ..., -3.1211e-01,
           -1.1327e-01,  2.5172e-01],
          [ 1.5850e-01, -4.3098e-02, -2.7288e-02,  ..., -2.9472e-02,
           -7.0143

In [65]:
x = x3.clone()  # B, H, 32, 1024
x2 = x3.clone()
x = x.reshape(B, H, 32, 1024)
x2 = x2.reshape(B, H, 32, 1024)

for i in range(32):
    x[:, :, i, :] = torch.arange(1024)
    x2[:, :, i, :] = torch.arange(1024)


x = x.reshape(B, H, 32, 32, 32)
x = x.transpose(-1, -2)
x = x @ f_32_fft
x = x.transpose(-1, -2)
x = x * twiddle_factors_fft_32_32 
x = x @ f_32_fft


print(f"{x2[0,0,:,0].real}")
x2 = x2.reshape(B, H, 32, 32, 32)
chunk_size = 16
chunks = x2.shape[2] // chunk_size
for i in range(chunks):
    block = x2[:, :, i*chunk_size:(i+1)*chunk_size]
    print(f"{block[0,0,0].real}")
    print(f"{block.shape=}")
    block = block.transpose(-1, -2)
    block = block @ f_32_fft
    block = block.transpose(-1, -2)
    block = block * twiddle_factors_fft_32_32
    block = block @ f_32_fft 
    x2[:, :, i*chunk_size:(i+1)*chunk_size] = block
    break
# all_close = torch.allclose(x, x2, atol=1e-2)
# print(torch.max(torch.abs(x.real - x2.real)))
# print(all_close)


tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0')
tensor([[0.0000e+00, 1.0000e+00, 2.0000e+00,  ..., 2.9000e+01, 3.0000e+01,
         3.1000e+01],
        [3.2000e+01, 3.3000e+01, 3.4000e+01,  ..., 6.1000e+01, 6.2000e+01,
         6.3000e+01],
        [6.4000e+01, 6.5000e+01, 6.6000e+01,  ..., 9.3000e+01, 9.4000e+01,
         9.5000e+01],
        ...,
        [9.2800e+02, 9.2900e+02, 9.3000e+02,  ..., 9.5700e+02, 9.5800e+02,
         9.5900e+02],
        [9.6000e+02, 9.6100e+02, 9.6200e+02,  ..., 9.8900e+02, 9.9000e+02,
         9.9100e+02],
        [9.9200e+02, 9.9300e+02, 9.9400e+02,  ..., 1.0210e+03, 1.0220e+03,
         1.0230e+03]], device='cuda:0')
block.shape=torch.Size([16, 768, 16, 32, 32])


### Real only

In [3]:
def fft_matrix(N):
    n = torch.arange(N)
    k = n.view(-1, 1)
    M = torch.exp(-2j * torch.pi * n * k / N).real
    return M

def compute_twiddle_factors_fft(n, m):
    """Compute the twiddle factors of size n x m"""
    n_a = torch.arange(n).view(-1, 1)
    m_a = torch.arange(m)
    N = n * m
    M = torch.exp(-2j * torch.pi * n_a * m_a / N).real
    return M

def ifft_matrix(N):
    n = torch.arange(N)
    k = n.view(-1, 1)
    M = torch.exp(2j * torch.pi * n * k / N).real
    return M

def compute_twiddle_factors_ifft(n, m):
    """Compute the twiddle factors of size n x m"""
    n_a = torch.arange(n).view(-1, 1)
    m_a = torch.arange(m)
    N = n * m
    M = torch.exp(2j * torch.pi * n_a * m_a / N).real
    return M / N


from torch.nn import functional as F
def ref_fftconv_test(u, k, N):
    L = u.shape[-1]
    u_f = torch.fft.fft(u, n = N).real
    k_f = torch.fft.fft(k, n = N).real
    y_f = u_f * k_f
    y = torch.fft.ifft(y_f, n = N).real[..., :L].to(u.dtype).contiguous()
    return y

def pytorch_test(u, k):
    ############# GET THE INPUTS #############
    f_mat = fft_matrix(N1)
    finv_mat = ifft_matrix(N1)
    
    # Normalization factor to make IFFT exact inverse of FFT
    twiddle_factors_fft = compute_twiddle_factors_fft(N1, N1).to(u.device)
    twiddle_factors_ifft = compute_twiddle_factors_ifft(N1, N1).to(u.device)

    k_f = torch.fft.fft(k, n = N).real

    ############# COMPUTE OUTPUT #############
    # step 1. FFT(U) using FFT matrix
    # compute the FFT
    x = u.reshape(B, H, N1, N1)
    x = x.transpose(-1, -2)
    x = torch.einsum('...i,ij->...j', x, f_mat)
    x = x.transpose(-1, -2)
    x = x * twiddle_factors_fft # (H, sqrt_N, sqrt_N) * (sqrt_N, sqrt_N), pointwise
    x = torch.einsum('...i,ij->...j', x, f_mat)
    # x = x.transpose(-1, -2)

    # pointwise multiplication 
    k_f = k_f.reshape(H, N1, N1).transpose(1,2) # to match the shape of x
    x = x * k_f

    # compute the IFFT
    # x = x.transpose(-1, -2)
    x = x @ finv_mat
    x = x.transpose(-1, -2)
    x = x * twiddle_factors_ifft # (H, sqrt_N, sqrt_N) * (sqrt_N, sqrt_N), pointwise
    x = x @ finv_mat
    x = x.transpose(-1, -2) # necessary to complete the ifft

    x = x.reshape(B, H, N)
    return x

print(f"{u.shape=}, {k.shape=}")
V_real = pytorch_test(u, k)
V_test = ref_fftconv_test(u, k, N) 
print(V_test[0, 0:6, :4])
print(V_real[0, 0:6, :4])
print(torch.allclose(V_real, V_test, atol=1e-3))

u.shape=torch.Size([16, 64, 4]), k.shape=torch.Size([64, 4])
tensor([[-0.4501, -0.5298, -0.4254, -0.5298],
        [ 0.7841,  1.0304,  0.2119,  1.0304],
        [ 0.7418,  0.1688,  0.4411,  0.1688],
        [-0.1613, -0.0828, -0.1343, -0.0828],
        [-0.1060, -1.5650, -2.9689, -1.5650],
        [ 0.1471, -0.3151,  0.2748, -0.3151]])
tensor([[-0.4501, -0.5298, -0.4254, -0.5298],
        [ 0.7841,  1.0304,  0.2119,  1.0304],
        [ 0.7418,  0.1688,  0.4411,  0.1688],
        [-0.1613, -0.0828, -0.1343, -0.0828],
        [-0.1060, -1.5650, -2.9689, -1.5650],
        [ 0.1471, -0.3151,  0.2748, -0.3151]])
True
