Toy NB 

Input : variable latente z de dimension Dz

Output : 

moyenne mu_x de dimension Dx
  
covariance sigma_x (Dx, Dx) définie positive et non diagonale
  
Doit scaler avec batch, sequence_length :

(Dz) => (Dx) + (Dx,Dx)

(L,Dz) => (L,Dx) + (L,Dx,Dx)

(B,L,Dz) => (B,L,Dx) + (B,L,Dx,Dx)

In [1]:
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.distributions.multivariate_normal import MultivariateNormal as MVN

In [2]:
Dz = 4
Dx = 16
B = 32
N = 50

In [3]:
# diagonale de la matrice de covariance : pas de souci.

diagonale = nn.Linear(Dz, Dx)

z = torch.randn(Dz)
x = diagonale(z)
print(f"z shape: {z.shape} => x shape: {x.shape}")

z = torch.randn(N,Dz)
x = diagonale(z)
print(f"z shape: {z.shape} => x shape: {x.shape}")

z = torch.randn(B,N,Dz)
x = diagonale(z)
print(f"z shape: {z.shape} => x shape: {x.shape}")


z shape: torch.Size([4]) => x shape: torch.Size([16])
z shape: torch.Size([50, 4]) => x shape: torch.Size([50, 16])
z shape: torch.Size([32, 50, 4]) => x shape: torch.Size([32, 50, 16])


In [4]:
# matric triangulaire inferieure de la matrice de covariance

class TriangularLower(nn.Module):
    def __init__(self, Dz=Dz, Dx=Dx):
        super(TriangularLower, self).__init__()
        self.full = nn.Sequential(
            nn.Linear(Dz, Dx * Dx),
            nn.Unflatten(-1, (Dx, Dx)),
            )

    def forward(self, x):
        return torch.tril(self.full(x))
    
triangulaire = TriangularLower()

In [5]:
z = torch.randn(Dz)
L = triangulaire(z)
print(f"z shape: {z.shape} => L shape: {L.shape}")

z = torch.randn(N,Dz)
L = triangulaire(z)
print(f"z shape: {z.shape} => L shape: {L.shape}")

z = torch.randn(B,N,Dz)
L = triangulaire(z)
print(f"z shape: {z.shape} => L shape: {L.shape}")

z shape: torch.Size([4]) => L shape: torch.Size([16, 16])
z shape: torch.Size([50, 4]) => L shape: torch.Size([50, 16, 16])
z shape: torch.Size([32, 50, 4]) => L shape: torch.Size([32, 50, 16, 16])


In [6]:
print(L[4, 3])
print(L.transpose(-1, -2)[4, 3])

tensor([[-0.7918,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.8348, -1.2289,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-1.0379, -1.9188,  2.2531,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.0453,  1.0335,  1.2678,  0.0320,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 1.0656, -1.2755, -0.0632,  1.1836,  0.0953,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 1.6644,  0.3225, -0.1978,  1.4753, -1.9900, -0.5815,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.6872, -0.0

In [7]:
class CovarianceMatrix(nn.Module):
    def __init__(self, Dz=Dz, Dx=Dx):
        super(CovarianceMatrix, self).__init__()
        self.diagonale = nn.Linear(Dz, Dx)
        self.full = nn.Sequential(
            nn.Linear(Dz, Dx * Dx),
            nn.Unflatten(-1, (Dx, Dx)),
            )

    def forward(self, z):
        D = torch.diag_embed(torch.exp(self.diagonale(z)))  # Diagonal elements > 0
        T = torch.tril(self.full(z), diagonal=-1)  # Lower triangular matrix without diagonal
        L = D + T
        C = torch.einsum('...ij,...jk->...ik', L, L.transpose(-1, -2))
        return C
    
covariance_matrix = CovarianceMatrix()

In [8]:
z = torch.randn(Dz)
C = covariance_matrix(z)
print(f"z shape: {z.shape} => C shape: {C.shape}")

z = torch.randn(N,Dz)
C = covariance_matrix(z)
print(f"z shape: {z.shape} => C shape: {C.shape}")

z = torch.randn(B,N,Dz)
C = covariance_matrix(z)
print(f"z shape: {z.shape} => C shape: {C.shape}")

z shape: torch.Size([4]) => C shape: torch.Size([16, 16])
z shape: torch.Size([50, 4]) => C shape: torch.Size([50, 16, 16])
z shape: torch.Size([32, 50, 4]) => C shape: torch.Size([32, 50, 16, 16])


In [9]:
sample_C = C[4, 3]

print(sample_C)

tensor([[ 3.9435e+00,  7.7925e-01, -1.9096e+00,  1.6269e+00, -2.5272e-01,
         -1.6761e+00,  1.5084e+00, -1.3799e+00, -5.7419e-01, -4.3216e-02,
          1.6997e+00,  2.5782e+00, -2.3575e-01,  1.2093e-01,  8.5835e-02,
         -1.4866e+00],
        [ 7.7925e-01,  8.6263e+00, -2.6181e-01, -1.3465e+00,  8.4034e-01,
          1.5706e+00,  1.5629e+00, -3.5443e+00,  3.2278e+00,  2.9954e+00,
          3.1033e-01,  2.8698e+00, -8.4483e-01,  2.0412e+00,  5.1449e-01,
          6.2226e-01],
        [-1.9096e+00, -2.6181e-01,  9.2772e+00, -1.3081e+00,  1.2961e-02,
          1.3979e+00,  7.5842e-01,  2.6656e+00, -2.4583e-01,  8.6481e-01,
         -7.9920e-01, -1.7183e+00,  4.2322e-01,  3.0312e+00, -1.4723e+00,
          7.3420e-01],
        [ 1.6269e+00, -1.3465e+00, -1.3081e+00,  4.3061e+00, -1.0930e+00,
         -2.9179e+00,  1.4964e+00,  1.6030e+00, -3.5819e+00,  1.0260e-02,
          2.5197e+00,  5.6585e-01,  7.5451e-02,  1.6937e+00,  1.8191e+00,
         -3.9038e-02],
        [-2.5272e-01

In [10]:
L = torch.linalg.cholesky(sample_C, upper=False)
print(L)

print(L @ L.T)

tensor([[ 1.9858e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00],
        [ 3.9241e-01,  2.9107e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00],
        [-9.6162e-01,  3.9694e-02,  2.8898e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00],
        [ 8.1925e-01, -5.7304e-01, -1.7218e-01,  1.8102e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00],
        [-1.2726e-01

In [11]:
# Instantiation de MVN avec des matrices de covariance non diagonales

z = torch.randn(Dz)
mu = diagonale(z)
C = covariance_matrix(z)
print(f"z shape: {z.shape} => mu shape: {mu.shape}, covar shape: {C.shape}")
mvn = MVN(loc=mu, covariance_matrix=C)
print(f"mvn loc: {mvn.loc.shape}, covariance_matrix: {mvn.covariance_matrix.shape}")
print(f"mvn batch_shape: {mvn.batch_shape}, event_shape: {mvn.event_shape}")
print()

z = torch.randn(N,Dz)
mu = diagonale(z)
C = covariance_matrix(z)
print(f"z shape: {z.shape} => mu shape: {mu.shape}, covar shape: {C.shape}")
mvn = MVN(loc=mu, covariance_matrix=C)
print(f"mvn loc: {mvn.loc.shape}, covariance_matrix: {mvn.covariance_matrix.shape}")
print(f"mvn batch_shape: {mvn.batch_shape}, event_shape: {mvn.event_shape}")
print()

z = torch.randn(B,N,Dz)
mu = diagonale(z)
C = covariance_matrix(z)
print(f"z shape: {z.shape} => mu shape: {mu.shape}, covar shape: {C.shape}")
mvn = MVN(loc=mu, covariance_matrix=C)
print(f"mvn loc: {mvn.loc.shape}, covariance_matrix: {mvn.covariance_matrix.shape}")
print(f"mvn batch_shape: {mvn.batch_shape}, event_shape: {mvn.event_shape}")
print()

z shape: torch.Size([4]) => mu shape: torch.Size([16]), covar shape: torch.Size([16, 16])
mvn loc: torch.Size([16]), covariance_matrix: torch.Size([16, 16])
mvn batch_shape: torch.Size([]), event_shape: torch.Size([16])

z shape: torch.Size([50, 4]) => mu shape: torch.Size([50, 16]), covar shape: torch.Size([50, 16, 16])
mvn loc: torch.Size([50, 16]), covariance_matrix: torch.Size([50, 16, 16])
mvn batch_shape: torch.Size([50]), event_shape: torch.Size([16])

z shape: torch.Size([32, 50, 4]) => mu shape: torch.Size([32, 50, 16]), covar shape: torch.Size([32, 50, 16, 16])


ValueError: Expected parameter covariance_matrix (Tensor of shape (32, 50, 16, 16)) of distribution MultivariateNormal(loc: torch.Size([32, 50, 16]), covariance_matrix: torch.Size([32, 50, 16, 16])) to satisfy the constraint PositiveDefinite(), but found invalid values:
tensor([[[[ 6.1311e-01,  9.6867e-02, -1.8256e-01,  ...,  8.5216e-02,
           -3.7406e-01, -1.0978e-01],
          [ 9.6867e-02,  6.1187e+00, -3.0513e-01,  ...,  6.9263e-02,
           -2.3341e+00, -1.0623e-01],
          [-1.8256e-01, -3.0513e-01,  1.7707e+00,  ...,  6.4690e-01,
           -2.0326e-01, -1.7755e-01],
          ...,
          [ 8.5216e-02,  6.9263e-02,  6.4690e-01,  ...,  3.0457e+00,
           -3.2580e-01, -9.8104e-01],
          [-3.7406e-01, -2.3341e+00, -2.0326e-01,  ..., -3.2580e-01,
            3.1916e+00, -1.4539e-01],
          [-1.0978e-01, -1.0623e-01, -1.7755e-01,  ..., -9.8104e-01,
           -1.4539e-01,  3.0551e+00]],

         [[ 1.2580e+00, -9.7726e-01, -1.0637e+00,  ..., -1.0888e+00,
           -3.2800e-01,  8.7231e-03],
          [-9.7726e-01,  1.8207e+00,  1.3682e+00,  ...,  2.0798e+00,
            1.4427e-01,  1.2880e+00],
          [-1.0637e+00,  1.3682e+00,  1.6020e+01,  ...,  1.9797e+00,
           -1.3649e+00,  6.0206e-01],
          ...,
          [-1.0888e+00,  2.0798e+00,  1.9797e+00,  ...,  1.3905e+01,
            2.9002e+00,  5.1212e+00],
          [-3.2800e-01,  1.4427e-01, -1.3649e+00,  ...,  2.9002e+00,
            6.6408e+00,  4.2795e+00],
          [ 8.7231e-03,  1.2880e+00,  6.0206e-01,  ...,  5.1212e+00,
            4.2795e+00,  1.2571e+01]],

         [[ 4.4866e-01, -1.1575e+00,  1.1722e-01,  ..., -8.8916e-01,
           -4.5811e-01, -4.7986e-03],
          [-1.1575e+00,  3.0898e+00, -1.3983e-01,  ...,  2.4146e+00,
            9.4979e-01,  3.4066e-01],
          [ 1.1722e-01, -1.3983e-01,  4.5716e+00,  ..., -1.9404e+00,
           -1.7138e+00,  1.8176e+00],
          ...,
          [-8.8916e-01,  2.4146e+00, -1.9404e+00,  ...,  1.8255e+01,
            3.6954e+00,  6.7367e+00],
          [-4.5811e-01,  9.4979e-01, -1.7138e+00,  ...,  3.6954e+00,
            1.4673e+01,  6.1939e+00],
          [-4.7986e-03,  3.4066e-01,  1.8176e+00,  ...,  6.7367e+00,
            6.1939e+00,  1.8852e+01]],

         ...,

         [[ 5.7482e-02, -5.5743e-02, -5.8068e-02,  ..., -2.4287e-02,
           -1.4379e-01,  2.8405e-01],
          [-5.5743e-02,  8.2071e+00, -3.5811e-01,  ..., -4.0201e-01,
           -4.5553e+00, -4.4024e-01],
          [-5.8068e-02, -3.5811e-01,  5.3327e-01,  ...,  1.4691e-01,
            3.5676e-01, -8.7271e-01],
          ...,
          [-2.4287e-02, -4.0201e-01,  1.4691e-01,  ...,  2.0957e+00,
            1.2394e+00, -8.4986e-01],
          [-1.4379e-01, -4.5553e+00,  3.5676e-01,  ...,  1.2394e+00,
            1.3760e+01, -6.4384e+00],
          [ 2.8405e-01, -4.4024e-01, -8.7271e-01,  ..., -8.4986e-01,
           -6.4384e+00,  1.2919e+01]],

         [[ 3.7911e-01, -3.8551e-01, -1.2065e+00,  ..., -5.9746e-01,
            4.6331e-01,  7.4566e-01],
          [-3.8551e-01,  6.0121e+00,  2.0599e+00,  ...,  3.7496e+00,
            9.2739e-01,  1.6538e+00],
          [-1.2065e+00,  2.0599e+00,  9.6658e+00,  ...,  3.8062e+00,
           -2.1493e+00, -4.2700e+00],
          ...,
          [-5.9746e-01,  3.7496e+00,  3.8062e+00,  ...,  1.1881e+01,
            3.8286e+00,  1.0958e+00],
          [ 4.6331e-01,  9.2739e-01, -2.1493e+00,  ...,  3.8286e+00,
            8.8805e+00,  2.4993e+00],
          [ 7.4566e-01,  1.6538e+00, -4.2700e+00,  ...,  1.0958e+00,
            2.4993e+00,  3.1620e+01]],

         [[ 6.7998e+00, -1.9992e-01, -1.3658e+00,  ..., -6.5505e-01,
           -1.2980e+00, -3.2363e+00],
          [-1.9992e-01,  1.9644e+00,  4.3967e-01,  ...,  1.1969e+00,
           -9.5393e-02,  1.1327e+00],
          [-1.3658e+00,  4.3967e-01,  1.9722e+01,  ...,  2.9861e+00,
           -2.1197e+00,  3.1487e+00],
          ...,
          [-6.5505e-01,  1.1969e+00,  2.9861e+00,  ...,  1.9256e+01,
            1.6144e+00,  1.4257e+00],
          [-1.2980e+00, -9.5393e-02, -2.1197e+00,  ...,  1.6144e+00,
            5.7340e+00,  2.0649e+00],
          [-3.2363e+00,  1.1327e+00,  3.1487e+00,  ...,  1.4257e+00,
            2.0649e+00,  1.0092e+01]]],


        [[[ 9.6806e-01, -2.7192e-01, -3.1382e-01,  ..., -2.7407e-01,
           -5.5260e-01, -2.8414e-01],
          [-2.7192e-01,  2.5106e+00,  3.0499e-01,  ...,  7.0947e-01,
           -9.7360e-01,  7.8231e-01],
          [-3.1382e-01,  3.0499e-01,  4.3243e+00,  ...,  7.3658e-01,
           -6.8169e-01,  2.6421e-01],
          ...,
          [-2.7407e-01,  7.0947e-01,  7.3658e-01,  ...,  3.9933e+00,
            8.8988e-01,  6.5073e-01],
          [-5.5260e-01, -9.7360e-01, -6.8169e-01,  ...,  8.8988e-01,
            3.9154e+00,  5.5690e-01],
          [-2.8414e-01,  7.8231e-01,  2.6421e-01,  ...,  6.5073e-01,
            5.5690e-01,  3.4115e+00]],

         [[ 6.2306e-02, -1.0899e-01, -9.5755e-02,  ..., -7.2648e-02,
           -2.3573e-03,  3.0827e-01],
          [-1.0899e-01,  5.0906e+00, -2.4746e-01,  ..., -5.1404e-01,
           -2.1952e+00, -9.8156e-01],
          [-9.5755e-02, -2.4746e-01,  5.0841e-01,  ...,  2.1678e-01,
            2.5967e-02, -8.4960e-01],
          ...,
          [-7.2648e-02, -5.1404e-01,  2.1678e-01,  ...,  2.6274e+00,
            1.6444e+00, -2.2711e+00],
          [-2.3573e-03, -2.1952e+00,  2.5967e-02,  ...,  1.6444e+00,
            1.0550e+01, -3.5305e+00],
          [ 3.0827e-01, -9.8156e-01, -8.4960e-01,  ..., -2.2711e+00,
           -3.5305e+00,  1.5665e+01]],

         [[ 2.1737e-01,  5.2410e-02, -3.0947e-01,  ...,  4.4223e-03,
           -2.4086e-01,  3.1720e-01],
          [ 5.2410e-02,  1.4134e+01, -2.3708e-01,  ...,  1.3249e+00,
           -4.5368e+00,  8.5033e-01],
          [-3.0947e-01, -2.3708e-01,  2.0104e+00,  ...,  7.2450e-01,
            3.2914e-01, -1.4958e+00],
          ...,
          [ 4.4223e-03,  1.3249e+00,  7.2450e-01,  ...,  3.1166e+00,
            2.3590e-01, -6.9823e-01],
          [-2.4086e-01, -4.5368e+00,  3.2914e-01,  ...,  2.3590e-01,
            5.8486e+00, -2.6301e+00],
          [ 3.1720e-01,  8.5033e-01, -1.4958e+00,  ..., -6.9823e-01,
           -2.6301e+00,  6.9871e+00]],

         ...,

         [[ 1.9781e+00, -6.2138e-01,  2.2737e-01,  ..., -4.5184e-01,
           -1.3500e+00, -1.3159e+00],
          [-6.2138e-01,  1.1035e+00,  1.4248e-01,  ...,  4.4988e-01,
           -4.3964e-01,  9.3845e-01],
          [ 2.2737e-01,  1.4248e-01,  6.8380e+00,  ...,  1.0777e-01,
           -1.5464e+00,  1.5037e+00],
          ...,
          [-4.5184e-01,  4.4988e-01,  1.0777e-01,  ...,  1.0151e+01,
            3.2021e+00,  2.1808e+00],
          [-1.3500e+00, -4.3964e-01, -1.5464e+00,  ...,  3.2021e+00,
            9.8916e+00,  2.0793e+00],
          [-1.3159e+00,  9.3845e-01,  1.5037e+00,  ...,  2.1808e+00,
            2.0793e+00,  8.7337e+00]],

         [[ 1.3532e+00, -9.8744e-02, -1.2141e-01,  ..., -6.4273e-02,
           -4.1405e-01, -7.2657e-01],
          [-9.8744e-02,  2.4142e+00, -5.2703e-02,  ...,  2.8179e-02,
           -7.9062e-01,  1.1177e-01],
          [-1.2141e-01, -5.2703e-02,  2.5951e+00,  ...,  6.1149e-01,
           -8.0089e-01,  5.5233e-01],
          ...,
          [-6.4273e-02,  2.8179e-02,  6.1149e-01,  ...,  3.6413e+00,
           -1.0087e-02, -5.5448e-01],
          [-4.1405e-01, -7.9062e-01, -8.0089e-01,  ..., -1.0087e-02,
            3.0092e+00,  1.0334e+00],
          [-7.2657e-01,  1.1177e-01,  5.5233e-01,  ..., -5.5448e-01,
            1.0334e+00,  2.7066e+00]],

         [[ 2.6875e+00, -4.2469e-01,  1.1428e+00,  ..., -1.9666e-02,
           -1.4289e+00, -2.3692e+00],
          [-4.2469e-01,  7.4561e-01, -2.0237e-01,  ..., -2.8111e-01,
           -5.0106e-01,  3.3325e-01],
          [ 1.1428e+00, -2.0237e-01,  3.4503e+00,  ...,  7.9411e-03,
           -1.7317e+00,  7.6221e-01],
          ...,
          [-1.9666e-02, -2.8111e-01,  7.9411e-03,  ...,  7.8005e+00,
            3.3322e+00, -1.1495e-01],
          [-1.4289e+00, -5.0106e-01, -1.7317e+00,  ...,  3.3322e+00,
            9.2808e+00,  1.8908e+00],
          [-2.3692e+00,  3.3325e-01,  7.6221e-01,  ..., -1.1495e-01,
            1.8908e+00,  7.0400e+00]]],


        [[[ 1.9709e+00,  5.4997e-01,  1.2695e+00,  ...,  8.8271e-01,
           -8.0090e-01, -2.0800e+00],
          [ 5.4997e-01,  2.4003e+00, -3.4700e-01,  ..., -1.3691e+00,
           -1.6224e+00, -2.0018e+00],
          [ 1.2695e+00, -3.4700e-01,  1.7274e+00,  ...,  1.3918e+00,
           -6.8175e-01, -1.3797e-01],
          ...,
          [ 8.8271e-01, -1.3691e+00,  1.3918e+00,  ...,  9.7795e+00,
           -1.6678e-01, -1.6310e+00],
          [-8.0090e-01, -1.6224e+00, -6.8175e-01,  ..., -1.6678e-01,
            1.1301e+01,  6.1771e+00],
          [-2.0800e+00, -2.0018e+00, -1.3797e-01,  ..., -1.6310e+00,
            6.1771e+00,  1.3581e+01]],

         [[ 9.3007e-01, -9.3663e-03, -7.1618e-01,  ..., -1.6414e-01,
           -3.3031e-01, -4.5510e-02],
          [-9.3663e-03,  6.0900e+00,  2.7575e-01,  ...,  1.5263e+00,
           -1.3287e+00,  1.1393e+00],
          [-7.1618e-01,  2.7575e-01,  5.2019e+00,  ...,  1.5276e+00,
           -4.3007e-01, -5.5196e-01],
          ...,
          [-1.6414e-01,  1.5263e+00,  1.5276e+00,  ...,  4.5305e+00,
            2.5261e-01,  2.0176e-01],
          [-3.3031e-01, -1.3287e+00, -4.3007e-01,  ...,  2.5261e-01,
            2.4838e+00,  4.1779e-01],
          [-4.5510e-02,  1.1393e+00, -5.5196e-01,  ...,  2.0176e-01,
            4.1779e-01,  3.3826e+00]],

         [[ 6.4854e-01, -4.6615e-01,  5.7102e-01,  ..., -1.8334e-01,
           -8.4860e-01, -5.1990e-01],
          [-4.6615e-01,  9.9318e-01, -3.9296e-01,  ..., -1.5567e-01,
           -5.1588e-01,  4.3392e-01],
          [ 5.7102e-01, -3.9296e-01,  2.1393e+00,  ..., -5.5335e-01,
           -1.3824e+00,  3.3355e-01],
          ...,
          [-1.8334e-01, -1.5567e-01, -5.5335e-01,  ...,  5.3296e+00,
            3.5921e+00,  2.2548e+00],
          [-8.4860e-01, -5.1588e-01, -1.3824e+00,  ...,  3.5921e+00,
            1.1469e+01,  1.2470e+00],
          [-5.1990e-01,  4.3392e-01,  3.3355e-01,  ...,  2.2548e+00,
            1.2470e+00,  7.5826e+00]],

         ...,

         [[ 9.7410e+00, -1.6798e-01, -2.3814e+00,  ..., -9.6065e-01,
           -7.7973e-01, -4.0561e+00],
          [-1.6798e-01,  2.0656e+00,  4.9375e-01,  ...,  1.4438e+00,
            4.1042e-01,  1.2318e+00],
          [-2.3814e+00,  4.9375e-01,  2.6449e+01,  ...,  4.1953e+00,
           -2.7969e+00,  3.9464e+00],
          ...,
          [-9.6065e-01,  1.4438e+00,  4.1953e+00,  ...,  2.4116e+01,
            8.5514e-01,  2.0763e+00],
          [-7.7973e-01,  4.1042e-01, -2.7969e+00,  ...,  8.5514e-01,
            5.4332e+00,  2.3535e+00],
          [-4.0561e+00,  1.2318e+00,  3.9464e+00,  ...,  2.0763e+00,
            2.3535e+00,  1.2494e+01]],

         [[ 7.4111e-01, -2.7447e-01,  1.5865e-01,  ..., -1.1842e-01,
           -2.7464e-01, -3.8585e-01],
          [-2.7447e-01,  1.5188e+00, -1.9068e-01,  ..., -3.5933e-01,
           -6.8371e-01, -5.1081e-02],
          [ 1.5865e-01, -1.9068e-01,  1.2781e+00,  ...,  1.1601e-01,
           -6.5392e-01,  3.7449e-01],
          ...,
          [-1.1842e-01, -3.5933e-01,  1.1601e-01,  ...,  2.2631e+00,
            7.8158e-01, -6.6928e-02],
          [-2.7464e-01, -6.8371e-01, -6.5392e-01,  ...,  7.8158e-01,
            4.1500e+00,  1.1052e+00],
          [-3.8585e-01, -5.1081e-02,  3.7449e-01,  ..., -6.6928e-02,
            1.1052e+00,  3.0237e+00]],

         [[ 1.9778e+00,  1.0240e+00, -7.1885e-01,  ...,  7.0488e-01,
           -2.1515e-01, -9.5577e-01],
          [ 1.0240e+00,  1.6222e+01, -1.3469e+00,  ...,  6.7523e-01,
           -1.6929e+00, -1.5817e+00],
          [-7.1885e-01, -1.3469e+00,  2.8265e+00,  ...,  1.4638e+00,
           -4.8999e-01,  2.7006e-01],
          ...,
          [ 7.0488e-01,  6.7523e-01,  1.4638e+00,  ...,  1.0917e+01,
           -3.3087e+00, -1.8187e+00],
          [-2.1515e-01, -1.6929e+00, -4.8999e-01,  ..., -3.3087e+00,
            4.6748e+00,  3.4837e+00],
          [-9.5577e-01, -1.5817e+00,  2.7006e-01,  ..., -1.8187e+00,
            3.4837e+00,  1.1110e+01]]],


        ...,


        [[[ 1.2750e-01, -1.1079e-01, -2.6879e-01,  ..., -1.1789e-01,
           -2.5244e-02,  3.7778e-01],
          [-1.1079e-01,  6.9972e+00,  1.5537e-01,  ...,  7.0701e-01,
           -2.0897e+00,  1.4617e-01],
          [-2.6879e-01,  1.5537e-01,  1.5124e+00,  ...,  5.6199e-01,
           -1.3074e-01, -1.5799e+00],
          ...,
          [-1.1789e-01,  7.0701e-01,  5.6199e-01,  ...,  2.3898e+00,
            1.1970e+00, -1.8169e+00],
          [-2.5244e-02, -2.0897e+00, -1.3074e-01,  ...,  1.1970e+00,
            6.1672e+00, -2.4703e+00],
          [ 3.7778e-01,  1.4617e-01, -1.5799e+00,  ..., -1.8169e+00,
           -2.4703e+00,  9.7849e+00]],

         [[ 1.4899e-01, -3.2661e-01, -2.3810e-01,  ..., -2.8864e-01,
           -2.7660e-02,  3.6731e-01],
          [-3.2661e-01,  2.5619e+00,  7.2223e-01,  ...,  1.0613e+00,
           -7.8302e-01, -1.6952e-01],
          [-2.3810e-01,  7.2223e-01,  1.7428e+00,  ...,  4.1995e-01,
           -4.5517e-01, -1.0494e+00],
          ...,
          [-2.8864e-01,  1.0613e+00,  4.1995e-01,  ...,  3.3140e+00,
            1.4054e+00, -1.0573e+00],
          [-2.7660e-02, -7.8302e-01, -4.5517e-01,  ...,  1.4054e+00,
            5.4998e+00, -1.1769e+00],
          [ 3.6731e-01, -1.6952e-01, -1.0494e+00,  ..., -1.0573e+00,
           -1.1769e+00,  7.0248e+00]],

         [[ 3.3093e-01,  3.2952e-01, -6.7226e-01,  ...,  1.4711e-01,
            4.2943e-02,  3.9816e-01],
          [ 3.2952e-01,  4.2433e+01, -1.9137e+00,  ...,  2.8154e+00,
           -3.6647e+00,  1.1578e-01],
          [-6.7226e-01, -1.9137e+00,  2.8598e+00,  ...,  9.8451e-01,
           -1.3493e-01, -2.0646e+00],
          ...,
          [ 1.4711e-01,  2.8154e+00,  9.8451e-01,  ...,  9.1319e+00,
           -4.5097e-01, -1.1186e+00],
          [ 4.2943e-02, -3.6647e+00, -1.3493e-01,  ..., -4.5097e-01,
            6.3327e+00,  6.6757e-01],
          [ 3.9816e-01,  1.1578e-01, -2.0646e+00,  ..., -1.1186e+00,
            6.6757e-01,  2.5877e+01]],

         ...,

         [[ 5.1658e-01, -3.0100e-01, -1.5463e-01,  ..., -2.1021e-01,
            2.9792e-01, -1.4050e-02],
          [-3.0100e-01,  1.7392e+00, -1.4632e-01,  ..., -3.8313e-01,
           -1.7074e-01, -4.0304e-01],
          [-1.5463e-01, -1.4632e-01,  7.9298e-01,  ...,  2.8602e-01,
           -7.3355e-01,  2.5405e-01],
          ...,
          [-2.1021e-01, -3.8313e-01,  2.8602e-01,  ...,  4.8975e+00,
           -7.9169e-01,  4.2069e-01],
          [ 2.9792e-01, -1.7074e-01, -7.3355e-01,  ..., -7.9169e-01,
            8.3937e+00,  2.9547e+00],
          [-1.4050e-02, -4.0304e-01,  2.5405e-01,  ...,  4.2069e-01,
            2.9547e+00,  1.5896e+01]],

         [[ 2.5952e-01, -4.6112e-01, -2.0071e-01,  ..., -3.6870e-01,
            2.6381e-01,  2.6023e-01],
          [-4.6112e-01,  1.6860e+00,  3.2709e-01,  ...,  4.6641e-01,
           -4.0827e-01, -4.7468e-01],
          [-2.0071e-01,  3.2709e-01,  8.7106e-01,  ...,  1.8016e-01,
           -8.3511e-01, -1.4601e-01],
          ...,
          [-3.6870e-01,  4.6641e-01,  1.8016e-01,  ...,  5.3975e+00,
            5.9604e-01, -2.6826e-01],
          [ 2.6381e-01, -4.0827e-01, -8.3511e-01,  ...,  5.9604e-01,
            7.5772e+00,  9.6698e-01],
          [ 2.6023e-01, -4.7468e-01, -1.4601e-01,  ..., -2.6826e-01,
            9.6698e-01,  1.4379e+01]],

         [[ 2.0150e-01, -2.9669e-01, -1.0560e-01,  ..., -2.0367e-01,
            2.7501e-01,  2.3461e-01],
          [-2.9669e-01,  1.7739e+00, -1.1217e-01,  ..., -3.6863e-01,
           -4.3589e-01, -8.4004e-01],
          [-1.0560e-01, -1.1217e-01,  4.4794e-01,  ...,  2.2684e-01,
           -5.8444e-01, -8.1467e-03],
          ...,
          [-2.0367e-01, -3.6863e-01,  2.2684e-01,  ...,  6.4674e+00,
           -9.9782e-03,  3.7078e-02],
          [ 2.7501e-01, -4.3589e-01, -5.8444e-01,  ..., -9.9782e-03,
            1.2260e+01,  2.6764e+00],
          [ 2.3461e-01, -8.4004e-01, -8.1467e-03,  ...,  3.7078e-02,
            2.6764e+00,  2.8795e+01]]],


        [[[ 2.4405e+00, -7.7844e-01,  2.8581e-01,  ..., -5.0498e-01,
            1.8256e-02, -1.6847e+00],
          [-7.7844e-01,  7.8232e-01, -1.1521e-01,  ..., -4.6517e-02,
            3.7464e-02,  4.7527e-01],
          [ 2.8581e-01, -1.1521e-01,  2.3150e+00,  ...,  2.1248e-02,
           -1.3713e+00,  1.2459e+00],
          ...,
          [-5.0498e-01, -4.6517e-02,  2.1248e-02,  ...,  5.8055e+00,
            2.4000e-01,  1.1402e+00],
          [ 1.8256e-02,  3.7464e-02, -1.3713e+00,  ...,  2.4000e-01,
            6.4009e+00,  2.4126e+00],
          [-1.6847e+00,  4.7527e-01,  1.2459e+00,  ...,  1.1402e+00,
            2.4126e+00,  7.9013e+00]],

         [[ 3.1330e+00,  1.8840e-01, -1.8671e+00,  ..., -3.6198e-01,
           -1.0673e-01, -9.4573e-01],
          [ 1.8840e-01,  5.8049e+00,  3.5227e-01,  ...,  2.2322e+00,
            2.5489e-01,  1.4691e+00],
          [-1.8671e+00,  3.5227e-01,  1.2766e+01,  ...,  3.3840e+00,
           -1.4161e+00,  4.5510e-01],
          ...,
          [-3.6198e-01,  2.2322e+00,  3.3840e+00,  ...,  1.0682e+01,
           -6.7884e-02,  1.3876e+00],
          [-1.0673e-01,  2.5489e-01, -1.4161e+00,  ..., -6.7884e-02,
            2.9041e+00,  2.0635e+00],
          [-9.4573e-01,  1.4691e+00,  4.5510e-01,  ...,  1.3876e+00,
            2.0635e+00,  7.2755e+00]],

         [[ 3.2568e-01,  6.0559e-02,  1.0811e-01,  ...,  1.2731e-01,
           -5.7820e-01, -2.1884e-02],
          [ 6.0559e-02,  6.0281e+00, -3.3258e-01,  ..., -4.4359e-01,
           -4.2111e+00, -2.5397e-01],
          [ 1.0811e-01, -3.3258e-01,  1.2048e+00,  ...,  3.9976e-01,
           -1.1229e-01, -2.2724e-01],
          ...,
          [ 1.2731e-01, -4.4359e-01,  3.9976e-01,  ...,  3.5459e+00,
            4.1243e-01,  1.1402e-01],
          [-5.7820e-01, -4.2111e+00, -1.1229e-01,  ...,  4.1243e-01,
            8.7461e+00, -1.9516e+00],
          [-2.1884e-02, -2.5397e-01, -2.2724e-01,  ...,  1.1402e-01,
           -1.9516e+00,  7.2195e+00]],

         ...,

         [[ 6.9700e-02, -1.2319e-01, -5.2072e-01,  ..., -2.0782e-01,
            9.8525e-02,  5.6072e-01],
          [-1.2319e-01,  1.8375e+01,  2.0099e+00,  ...,  5.6569e+00,
           -1.5277e+00,  2.9388e+00],
          [-5.2072e-01,  2.0099e+00,  6.2890e+00,  ...,  2.8054e+00,
           -7.9038e-01, -6.5338e+00],
          ...,
          [-2.0782e-01,  5.6569e+00,  2.8054e+00,  ...,  9.9615e+00,
            5.3409e+00, -2.5418e+00],
          [ 9.8525e-02, -1.5277e+00, -7.9038e-01,  ...,  5.3409e+00,
            1.4554e+01, -1.3084e+00],
          [ 5.6072e-01,  2.9388e+00, -6.5338e+00,  ..., -2.5418e+00,
           -1.3084e+00,  3.5962e+01]],

         [[ 6.3476e-01, -6.8938e-01,  3.3313e-01,  ..., -4.0814e-01,
           -1.1174e-01, -3.5765e-01],
          [-6.8938e-01,  1.1116e+00, -3.8841e-01,  ...,  1.2487e-01,
           -1.3243e-01,  3.1150e-01],
          [ 3.3313e-01, -3.8841e-01,  1.1325e+00,  ..., -5.4055e-01,
           -8.3353e-01,  5.5264e-01],
          ...,
          [-4.0814e-01,  1.2487e-01, -5.4055e-01,  ...,  3.7520e+00,
            2.1745e+00,  1.5953e+00],
          [-1.1174e-01, -1.3243e-01, -8.3353e-01,  ...,  2.1745e+00,
            5.9899e+00,  1.6581e+00],
          [-3.5765e-01,  3.1150e-01,  5.5264e-01,  ...,  1.5953e+00,
            1.6581e+00,  4.8738e+00]],

         [[ 4.6434e-01, -7.6840e-01, -2.0618e-02,  ..., -6.0285e-01,
           -4.6919e-01,  2.1332e-02],
          [-7.6840e-01,  1.7092e+00,  2.6273e-01,  ...,  1.2470e+00,
            2.0791e-01,  4.8141e-01],
          [-2.0618e-02,  2.6273e-01,  3.8132e+00,  ..., -6.7888e-01,
           -1.1303e+00,  8.0476e-01],
          ...,
          [-6.0285e-01,  1.2470e+00, -6.7888e-01,  ...,  8.4673e+00,
            2.3396e+00,  3.5794e+00],
          [-4.6919e-01,  2.0791e-01, -1.1303e+00,  ...,  2.3396e+00,
            9.7624e+00,  2.7858e+00],
          [ 2.1332e-02,  4.8141e-01,  8.0476e-01,  ...,  3.5794e+00,
            2.7858e+00,  1.0615e+01]]],


        [[[ 1.8359e+00,  4.1810e-01, -2.2939e+00,  ..., -2.7460e-01,
            1.3913e-02,  1.5398e-01],
          [ 4.1810e-01,  2.0170e+01,  6.4525e-01,  ...,  6.3217e+00,
            2.8789e-01,  4.0785e+00],
          [-2.2939e+00,  6.4525e-01,  1.8438e+01,  ...,  5.4428e+00,
           -7.2185e-01, -2.9205e+00],
          ...,
          [-2.7460e-01,  6.3217e+00,  5.4428e+00,  ...,  1.3174e+01,
            1.7566e+00,  3.0897e+00],
          [ 1.3913e-02,  2.8789e-01, -7.2185e-01,  ...,  1.7566e+00,
            4.7362e+00,  2.9085e+00],
          [ 1.5398e-01,  4.0785e+00, -2.9205e+00,  ...,  3.0897e+00,
            2.9085e+00,  1.1613e+01]],

         [[ 5.9233e-02, -7.2817e-02, -2.0129e-01,  ..., -7.6232e-02,
           -2.6886e-02,  3.6897e-01],
          [-7.2817e-02,  1.0675e+01,  3.3798e-02,  ...,  7.5852e-01,
           -3.4272e+00,  1.9324e-02],
          [-2.0129e-01,  3.3798e-02,  1.3063e+00,  ...,  4.9787e-01,
            9.6151e-02, -2.1508e+00],
          ...,
          [-7.6232e-02,  7.5852e-01,  4.9787e-01,  ...,  3.0071e+00,
            1.8625e+00, -2.6354e+00],
          [-2.6886e-02, -3.4272e+00,  9.6151e-02,  ...,  1.8625e+00,
            1.1245e+01, -4.9539e+00],
          [ 3.6897e-01,  1.9324e-02, -2.1508e+00,  ..., -2.6354e+00,
           -4.9539e+00,  1.6066e+01]],

         [[ 1.3375e+00, -9.3252e-02, -1.0132e+00,  ..., -2.9818e-01,
            2.6510e-01, -1.9499e-01],
          [-9.3252e-02,  4.2974e+00,  1.5921e-01,  ...,  1.0084e+00,
            2.6866e-01,  5.8460e-01],
          [-1.0132e+00,  1.5921e-01,  4.5210e+00,  ...,  1.5269e+00,
           -1.2511e+00, -1.1499e-02],
          ...,
          [-2.9818e-01,  1.0084e+00,  1.5269e+00,  ...,  6.0055e+00,
           -6.8673e-01,  4.6854e-01],
          [ 2.6510e-01,  2.6866e-01, -1.2511e+00,  ..., -6.8673e-01,
            3.4046e+00,  1.8209e+00],
          [-1.9499e-01,  5.8460e-01, -1.1499e-02,  ...,  4.6854e-01,
            1.8209e+00,  8.6216e+00]],

         ...,

         [[ 2.9112e-01, -4.5088e-01, -5.1864e-01,  ..., -4.7794e-01,
           -1.8717e-01,  4.3301e-01],
          [-4.5088e-01,  2.8692e+00,  1.3953e+00,  ...,  2.1934e+00,
           -6.0992e-01,  8.8211e-01],
          [-5.1864e-01,  1.3953e+00,  6.3252e+00,  ...,  1.3938e+00,
           -3.8714e-01, -1.6050e+00],
          ...,
          [-4.7794e-01,  2.1934e+00,  1.3938e+00,  ...,  6.8749e+00,
            1.4879e+00,  1.8644e+00],
          [-1.8717e-01, -6.0992e-01, -3.8714e-01,  ...,  1.4879e+00,
            7.6457e+00,  1.9204e+00],
          [ 4.3301e-01,  8.8211e-01, -1.6050e+00,  ...,  1.8644e+00,
            1.9204e+00,  1.1169e+01]],

         [[ 9.6610e-01, -2.6776e-01,  1.4628e-01,  ..., -1.3139e-01,
           -5.7862e-01, -5.5899e-01],
          [-2.6776e-01,  1.6488e+00, -5.6185e-02,  ..., -8.2482e-02,
           -8.8530e-01,  2.3940e-01],
          [ 1.4628e-01, -5.6185e-02,  2.1881e+00,  ...,  1.8271e-01,
           -8.1480e-01,  4.6312e-01],
          ...,
          [-1.3139e-01, -8.2482e-02,  1.8271e-01,  ...,  2.9520e+00,
            1.1485e+00,  3.4865e-02],
          [-5.7862e-01, -8.8530e-01, -8.1480e-01,  ...,  1.1485e+00,
            4.4068e+00,  6.3282e-01],
          [-5.5899e-01,  2.3940e-01,  4.6312e-01,  ...,  3.4865e-02,
            6.3282e-01,  2.6567e+00]],

         [[ 6.5777e-01, -1.9410e-01, -2.6252e-01,  ..., -1.8092e-01,
           -3.0139e-01, -6.9762e-02],
          [-1.9410e-01,  2.9833e+00,  1.3677e-01,  ...,  4.0474e-01,
           -1.0701e+00,  4.1931e-01],
          [-2.6252e-01,  1.3677e-01,  2.5064e+00,  ...,  5.5279e-01,
           -5.2162e-01, -3.3300e-02],
          ...,
          [-1.8092e-01,  4.0474e-01,  5.5279e-01,  ...,  2.3323e+00,
            4.4356e-01, -1.7217e-01],
          [-3.0139e-01, -1.0701e+00, -5.2162e-01,  ...,  4.4356e-01,
            2.7226e+00,  4.7279e-02],
          [-6.9762e-02,  4.1931e-01, -3.3300e-02,  ..., -1.7217e-01,
            4.7279e-02,  2.2213e+00]]]], grad_fn=<ExpandBackward0>)

In [12]:
# Test for positive definiteness of covariance matrix by brute force sampling...

N_TESTS = int(1e+5)

for _ in tqdm(range(N_TESTS)):
    
    z = torch.randn(B,N,Dz)
    mu = diagonale(z)
    C = covariance_matrix(z)
    mvn = MVN(loc=mu, covariance_matrix=C)

  0%|          | 0/100000 [00:00<?, ?it/s]


ValueError: Expected parameter covariance_matrix (Tensor of shape (32, 50, 16, 16)) of distribution MultivariateNormal(loc: torch.Size([32, 50, 16]), covariance_matrix: torch.Size([32, 50, 16, 16])) to satisfy the constraint PositiveDefinite(), but found invalid values:
tensor([[[[ 5.8506e-01, -9.9067e-01,  3.1407e-02,  ..., -7.3362e-01,
            3.2397e-01, -1.1008e-01],
          [-9.9067e-01,  1.8618e+00, -2.1245e-02,  ...,  1.1000e+00,
           -4.3463e-01,  2.2256e-01],
          [ 3.1407e-02, -2.1245e-02,  1.0783e+00,  ..., -5.8787e-01,
           -9.8706e-01,  7.9053e-01],
          ...,
          [-7.3362e-01,  1.1000e+00, -5.8787e-01,  ...,  7.8268e+00,
            1.1289e+00,  2.2967e+00],
          [ 3.2397e-01, -4.3463e-01, -9.8706e-01,  ...,  1.1289e+00,
            7.6800e+00,  1.6737e+00],
          [-1.1008e-01,  2.2256e-01,  7.9053e-01,  ...,  2.2967e+00,
            1.6737e+00,  1.0396e+01]],

         [[ 1.6187e+00, -4.1551e-01,  3.2606e-01,  ..., -1.8111e-01,
           -8.1370e-02, -1.1681e+00],
          [-4.1551e-01,  9.7291e-01, -2.1271e-01,  ..., -3.7350e-01,
           -1.6467e-01,  4.0391e-02],
          [ 3.2606e-01, -2.1271e-01,  1.5266e+00,  ...,  1.6629e-01,
           -9.7934e-01,  7.4380e-01],
          ...,
          [-1.8111e-01, -3.7350e-01,  1.6629e-01,  ...,  4.4218e+00,
            5.8417e-02,  4.3728e-01],
          [-8.1370e-02, -1.6467e-01, -9.7934e-01,  ...,  5.8417e-02,
            6.2317e+00,  2.7782e+00],
          [-1.1681e+00,  4.0391e-02,  7.4380e-01,  ...,  4.3728e-01,
            2.7782e+00,  6.8714e+00]],

         [[ 6.7049e-02, -9.1602e-02,  1.4981e-01,  ...,  1.3976e-02,
           -8.4205e-02,  1.5930e-01],
          [-9.1602e-02,  2.5819e+00, -8.9945e-01,  ..., -1.8414e+00,
           -2.1673e+00, -1.5396e+00],
          [ 1.4981e-01, -8.9945e-01,  6.4780e-01,  ...,  4.8495e-01,
            3.0607e-01,  6.8324e-01],
          ...,
          [ 1.3976e-02, -1.8414e+00,  4.8495e-01,  ...,  3.9549e+00,
            2.5879e+00,  9.4202e-01],
          [-8.4205e-02, -2.1673e+00,  3.0607e-01,  ...,  2.5879e+00,
            1.4146e+01, -1.0560e+00],
          [ 1.5930e-01, -1.5396e+00,  6.8324e-01,  ...,  9.4202e-01,
           -1.0560e+00,  1.6164e+01]],

         ...,

         [[ 9.4643e-01,  1.2511e-02, -2.1227e-01,  ..., -6.8034e-03,
           -6.3265e-01, -3.5714e-01],
          [ 1.2511e-02,  4.1997e+00,  3.1674e-02,  ...,  4.8463e-01,
           -1.8857e+00,  4.4527e-01],
          [-2.1227e-01,  3.1674e-02,  3.2088e+00,  ...,  8.0374e-01,
           -4.4000e-01,  6.3629e-02],
          ...,
          [-6.8034e-03,  4.8463e-01,  8.0374e-01,  ...,  3.7845e+00,
            2.4540e-01, -3.2968e-01],
          [-6.3265e-01, -1.8857e+00, -4.4000e-01,  ...,  2.4540e-01,
            3.8800e+00,  6.6421e-02],
          [-3.5714e-01,  4.4527e-01,  6.3629e-02,  ..., -3.2968e-01,
            6.6421e-02,  2.9553e+00]],

         [[ 8.1171e-01, -2.0324e-01, -7.4102e-01,  ..., -3.5202e-01,
           -6.6506e-01,  9.6848e-02],
          [-2.0324e-01,  5.0676e+00,  9.1615e-01,  ...,  2.3904e+00,
           -1.7689e+00,  2.0857e+00],
          [-7.4102e-01,  9.1615e-01,  1.0152e+01,  ...,  2.0959e+00,
           -2.4517e-02, -1.0840e+00],
          ...,
          [-3.5202e-01,  2.3904e+00,  2.0959e+00,  ...,  7.8997e+00,
            1.3191e+00,  2.8645e+00],
          [-6.6506e-01, -1.7689e+00, -2.4517e-02,  ...,  1.3191e+00,
            8.2743e+00,  1.8582e+00],
          [ 9.6848e-02,  2.0857e+00, -1.0840e+00,  ...,  2.8645e+00,
            1.8582e+00,  9.9627e+00]],

         [[ 7.9983e-01, -3.1409e-01, -4.1947e-01,  ..., -3.2557e-01,
           -2.2194e-01, -8.4667e-02],
          [-3.1409e-01,  2.4860e+00,  3.3045e-01,  ...,  6.8251e-01,
           -5.9803e-01,  6.1541e-01],
          [-4.1947e-01,  3.3045e-01,  3.4494e+00,  ...,  7.3368e-01,
           -7.3284e-01,  7.2782e-02],
          ...,
          [-3.2557e-01,  6.8251e-01,  7.3368e-01,  ...,  3.0172e+00,
            5.7030e-01,  2.5282e-01],
          [-2.2194e-01, -5.9803e-01, -7.3284e-01,  ...,  5.7030e-01,
            2.3233e+00,  4.5472e-01],
          [-8.4667e-02,  6.1541e-01,  7.2782e-02,  ...,  2.5282e-01,
            4.5472e-01,  2.5555e+00]]],


        [[[ 2.5922e-01,  4.7955e-02, -7.6564e-01,  ..., -1.2555e-01,
           -2.4631e-02,  5.5054e-01],
          [ 4.7955e-02,  2.4618e+01,  4.6761e-01,  ...,  4.9721e+00,
           -2.7581e+00,  3.2276e+00],
          [-7.6564e-01,  4.6761e-01,  5.7166e+00,  ...,  2.1454e+00,
           -2.8561e-02, -3.8226e+00],
          ...,
          [-1.2555e-01,  4.9721e+00,  2.1454e+00,  ...,  6.6368e+00,
            2.1016e+00,  7.7444e-03],
          [-2.4631e-02, -2.7581e+00, -2.8561e-02,  ...,  2.1016e+00,
            6.1525e+00, -6.5883e-01],
          [ 5.5054e-01,  3.2276e+00, -3.8226e+00,  ...,  7.7444e-03,
           -6.5883e-01,  1.3400e+01]],

         [[ 8.1432e-01, -1.0773e-01, -1.4218e+00,  ..., -4.1356e-01,
            6.6174e-01,  4.6206e-01],
          [-1.0773e-01,  8.1991e+00,  4.6988e-01,  ...,  2.4776e+00,
            1.5854e+00,  1.2677e+00],
          [-1.4218e+00,  4.6988e-01,  6.4009e+00,  ...,  2.5124e+00,
           -2.0730e+00, -1.9496e+00],
          ...,
          [-4.1356e-01,  2.4776e+00,  2.5124e+00,  ...,  1.0365e+01,
            7.3882e-01,  1.5151e+00],
          [ 6.6174e-01,  1.5854e+00, -2.0730e+00,  ...,  7.3882e-01,
            8.4393e+00,  2.9091e+00],
          [ 4.6206e-01,  1.2677e+00, -1.9496e+00,  ...,  1.5151e+00,
            2.9091e+00,  3.7205e+01]],

         [[ 1.7783e+00, -6.4468e-01,  1.1841e+00,  ..., -1.0279e-01,
            1.5282e-01, -1.7809e+00],
          [-6.4468e-01,  5.4118e-01, -6.1543e-01,  ..., -6.2215e-01,
           -9.5846e-02,  1.9139e-01],
          [ 1.1841e+00, -6.1543e-01,  1.4476e+00,  ...,  1.9486e-01,
           -7.0716e-01,  9.6670e-02],
          ...,
          [-1.0279e-01, -6.2215e-01,  1.9486e-01,  ...,  8.3857e+00,
            9.4943e-01,  3.0449e+00],
          [ 1.5282e-01, -9.5846e-02, -7.0716e-01,  ...,  9.4943e-01,
            1.5627e+01,  6.8956e+00],
          [-1.7809e+00,  1.9139e-01,  9.6670e-02,  ...,  3.0449e+00,
            6.8956e+00,  1.8131e+01]],

         ...,

         [[ 1.0080e-01, -2.3138e-01,  7.7465e-02,  ..., -1.2150e-01,
           -1.0780e-01,  2.0263e-01],
          [-2.3138e-01,  1.8435e+00, -3.4222e-01,  ..., -3.7330e-01,
           -1.0966e+00, -7.1685e-01],
          [ 7.7465e-02, -3.4222e-01,  4.3896e-01,  ..., -1.9706e-01,
           -1.7685e-01,  1.2680e-01],
          ...,
          [-1.2150e-01, -3.7330e-01, -1.9706e-01,  ...,  2.2562e+00,
            2.2835e+00,  6.3397e-02],
          [-1.0780e-01, -1.0966e+00, -1.7685e-01,  ...,  2.2835e+00,
            8.7904e+00, -1.9274e+00],
          [ 2.0263e-01, -7.1685e-01,  1.2680e-01,  ...,  6.3397e-02,
           -1.9274e+00,  7.6384e+00]],

         [[ 1.7336e-01, -2.1038e-01, -2.8787e-01,  ..., -2.0046e-01,
            3.4042e-02,  3.6361e-01],
          [-2.1038e-01,  3.8798e+00,  3.6328e-01,  ...,  6.0417e-01,
           -1.0379e+00, -5.0353e-02],
          [-2.8787e-01,  3.6328e-01,  1.5226e+00,  ...,  5.3703e-01,
           -4.3039e-01, -1.1560e+00],
          ...,
          [-2.0046e-01,  6.0417e-01,  5.3703e-01,  ...,  2.5468e+00,
            1.1046e+00, -1.5415e+00],
          [ 3.4042e-02, -1.0379e+00, -4.3039e-01,  ...,  1.1046e+00,
            4.8627e+00, -1.2473e+00],
          [ 3.6361e-01, -5.0353e-02, -1.1560e+00,  ..., -1.5415e+00,
           -1.2473e+00,  8.4273e+00]],

         [[ 1.5654e+00, -9.5508e-01, -3.0696e-01,  ..., -8.1758e-01,
            1.5890e-01, -6.9397e-01],
          [-9.5508e-01,  1.1217e+00,  2.9053e-01,  ...,  6.0177e-01,
            3.9066e-02,  6.5587e-01],
          [-3.0696e-01,  2.9053e-01,  3.3650e+00,  ...,  1.5835e-01,
           -1.4796e+00,  1.3181e+00],
          ...,
          [-8.1758e-01,  6.0177e-01,  1.5835e-01,  ...,  5.7955e+00,
            5.6379e-01,  1.7561e+00],
          [ 1.5890e-01,  3.9066e-02, -1.4796e+00,  ...,  5.6379e-01,
            3.9829e+00,  1.4419e+00],
          [-6.9397e-01,  6.5587e-01,  1.3181e+00,  ...,  1.7561e+00,
            1.4419e+00,  6.1425e+00]]],


        [[[ 1.1688e-01, -4.3430e-02, -3.4938e-01,  ..., -9.0453e-02,
           -9.7757e-02,  4.2786e-01],
          [-4.3430e-02,  1.4602e+01,  3.0679e-01,  ...,  2.2611e+00,
           -3.9131e+00,  1.4461e+00],
          [-3.4938e-01,  3.0679e-01,  2.5500e+00,  ...,  9.4533e-01,
            2.4027e-01, -2.6605e+00],
          ...,
          [-9.0453e-02,  2.2611e+00,  9.4533e-01,  ...,  3.4779e+00,
            1.4063e+00, -1.1856e+00],
          [-9.7757e-02, -3.9131e+00,  2.4027e-01,  ...,  1.4063e+00,
            8.2637e+00, -3.3714e+00],
          [ 4.2786e-01,  1.4461e+00, -2.6605e+00,  ..., -1.1856e+00,
           -3.3714e+00,  1.1263e+01]],

         [[ 4.6136e-02,  2.5567e-02,  4.0251e-03,  ...,  6.4500e-02,
           -6.0037e-02,  2.3629e-01],
          [ 2.5567e-02,  1.3122e+01, -1.7563e+00,  ..., -2.9405e+00,
           -5.7177e+00, -2.7299e+00],
          [ 4.0251e-03, -1.7563e+00,  3.6605e-01,  ...,  5.1704e-01,
            6.9325e-01,  1.1055e-01],
          ...,
          [ 6.4500e-02, -2.9405e+00,  5.1704e-01,  ...,  4.8269e+00,
            8.8747e-01, -1.1736e+00],
          [-6.0037e-02, -5.7177e+00,  6.9325e-01,  ...,  8.8747e-01,
            1.6299e+01, -4.1514e+00],
          [ 2.3629e-01, -2.7299e+00,  1.1055e-01,  ..., -1.1736e+00,
           -4.1514e+00,  2.5460e+01]],

         [[ 1.4042e-01, -3.7308e-01, -5.0856e-01,  ..., -4.0750e-01,
            4.9229e-02,  5.3975e-01],
          [-3.7308e-01,  3.6513e+00,  1.9861e+00,  ...,  2.7890e+00,
           -5.8212e-01,  2.8093e-01],
          [-5.0856e-01,  1.9861e+00,  5.3247e+00,  ...,  2.0116e+00,
           -7.0162e-01, -3.1725e+00],
          ...,
          [-4.0750e-01,  2.7890e+00,  2.0116e+00,  ...,  8.1440e+00,
            2.3761e+00, -8.0531e-01],
          [ 4.9229e-02, -5.8212e-01, -7.0162e-01,  ...,  2.3761e+00,
            8.1572e+00,  8.9326e-01],
          [ 5.3975e-01,  2.8093e-01, -3.1725e+00,  ..., -8.0531e-01,
            8.9326e-01,  1.4333e+01]],

         ...,

         [[ 2.9240e-01,  1.1447e-01, -3.7618e-01,  ...,  3.5650e-02,
           -3.3617e-01,  2.8714e-01],
          [ 1.1447e-01,  1.6417e+01, -2.0054e-01,  ...,  1.9094e+00,
           -5.0753e+00,  1.2606e+00],
          [-3.7618e-01, -2.0054e-01,  2.6578e+00,  ...,  9.5666e-01,
            4.2203e-01, -1.5699e+00],
          ...,
          [ 3.5650e-02,  1.9094e+00,  9.5666e-01,  ...,  4.0178e+00,
            1.0363e-02, -2.8129e-01],
          [-3.3617e-01, -5.0753e+00,  4.2203e-01,  ...,  1.0363e-02,
            5.8713e+00, -2.2079e+00],
          [ 2.8714e-01,  1.2606e+00, -1.5699e+00,  ..., -2.8129e-01,
           -2.2079e+00,  6.7690e+00]],

         [[ 4.6503e-02, -3.4086e-01,  8.4893e-02,  ..., -2.2093e-01,
           -7.6180e-02,  2.2865e-01],
          [-3.4086e-01,  2.7629e+00, -5.8479e-01,  ...,  1.3492e+00,
           -5.4946e-02, -1.6021e+00],
          [ 8.4893e-02, -5.8479e-01,  4.8533e-01,  ..., -1.0027e+00,
           -5.1827e-01,  4.5284e-01],
          ...,
          [-2.2093e-01,  1.3492e+00, -1.0027e+00,  ...,  9.8429e+00,
            3.2738e+00,  9.8488e-01],
          [-7.6180e-02, -5.4946e-02, -5.1827e-01,  ...,  3.2738e+00,
            1.5826e+01, -1.5231e+00],
          [ 2.2865e-01, -1.6021e+00,  4.5284e-01,  ...,  9.8488e-01,
           -1.5231e+00,  1.4260e+01]],

         [[ 1.6465e-01, -3.4863e-01, -1.9508e-01,  ..., -2.9138e-01,
           -1.5552e-03,  3.3099e-01],
          [-3.4863e-01,  2.2164e+00,  5.2878e-01,  ...,  7.9016e-01,
           -6.6402e-01, -3.0166e-01],
          [-1.9508e-01,  5.2878e-01,  1.3335e+00,  ...,  2.3243e-01,
           -5.0645e-01, -6.8140e-01],
          ...,
          [-2.9138e-01,  7.9016e-01,  2.3243e-01,  ...,  3.0426e+00,
            1.3700e+00, -1.0224e+00],
          [-1.5552e-03, -6.6402e-01, -5.0645e-01,  ...,  1.3700e+00,
            4.9458e+00, -1.0143e+00],
          [ 3.3099e-01, -3.0166e-01, -6.8140e-01,  ..., -1.0224e+00,
           -1.0143e+00,  6.2951e+00]]],


        ...,


        [[[ 5.3068e-01, -1.4441e-01, -8.3369e-01,  ..., -3.0402e-01,
           -8.5341e-02,  3.7348e-01],
          [-1.4441e-01,  7.1730e+00,  7.6152e-01,  ...,  2.4617e+00,
           -9.8194e-01,  1.7416e+00],
          [-8.3369e-01,  7.6152e-01,  6.1450e+00,  ...,  1.9786e+00,
           -4.3510e-01, -1.8541e+00],
          ...,
          [-3.0402e-01,  2.4617e+00,  1.9786e+00,  ...,  4.8916e+00,
            1.1973e+00,  6.3253e-01],
          [-8.5341e-02, -9.8194e-01, -4.3510e-01,  ...,  1.1973e+00,
            3.2906e+00,  4.8253e-01],
          [ 3.7348e-01,  1.7416e+00, -1.8541e+00,  ...,  6.3253e-01,
            4.8253e-01,  6.0983e+00]],

         [[ 6.5240e+00, -3.4259e-01, -2.2096e+00,  ..., -9.9583e-01,
           -6.3553e-01, -2.5876e+00],
          [-3.4259e-01,  2.2927e+00,  6.3602e-01,  ...,  1.6647e+00,
            3.3142e-01,  1.4779e+00],
          [-2.2096e+00,  6.3602e-01,  2.4486e+01,  ...,  3.9935e+00,
           -2.3401e+00,  2.8967e+00],
          ...,
          [-9.9583e-01,  1.6647e+00,  3.9935e+00,  ...,  2.0055e+01,
            1.3692e+00,  2.8023e+00],
          [-6.3553e-01,  3.3142e-01, -2.3401e+00,  ...,  1.3692e+00,
            4.8282e+00,  2.5330e+00],
          [-2.5876e+00,  1.4779e+00,  2.8967e+00,  ...,  2.8023e+00,
            2.5330e+00,  1.0935e+01]],

         [[ 3.3055e-01, -1.5004e-01, -4.8994e-01,  ..., -1.9955e-01,
            9.6798e-02,  3.4073e-01],
          [-1.5004e-01,  5.3194e+00,  2.1820e-01,  ...,  8.3328e-01,
           -7.1422e-01,  3.2361e-01],
          [-4.8994e-01,  2.1820e-01,  2.3079e+00,  ...,  8.6850e-01,
           -6.3268e-01, -1.1338e+00],
          ...,
          [-1.9955e-01,  8.3328e-01,  8.6850e-01,  ...,  3.4121e+00,
            5.0636e-01, -9.6485e-01],
          [ 9.6798e-02, -7.1422e-01, -6.3268e-01,  ...,  5.0636e-01,
            3.8198e+00, -1.5354e-02],
          [ 3.4073e-01,  3.2361e-01, -1.1338e+00,  ..., -9.6485e-01,
           -1.5354e-02,  9.2886e+00]],

         ...,

         [[ 3.2663e-01,  2.5215e-02, -6.8358e-01,  ..., -1.1804e-01,
           -4.9571e-02,  4.4468e-01],
          [ 2.5215e-02,  1.5116e+01,  2.2194e-01,  ...,  2.9119e+00,
           -2.1794e+00,  1.8463e+00],
          [-6.8358e-01,  2.2194e-01,  4.2750e+00,  ...,  1.5888e+00,
           -1.5476e-01, -2.4531e+00],
          ...,
          [-1.1804e-01,  2.9119e+00,  1.5888e+00,  ...,  4.8700e+00,
            1.0831e+00, -3.9611e-01],
          [-4.9571e-02, -2.1794e+00, -1.5476e-01,  ...,  1.0831e+00,
            4.1981e+00, -6.2507e-01],
          [ 4.4468e-01,  1.8463e+00, -2.4531e+00,  ..., -3.9611e-01,
           -6.2507e-01,  8.8990e+00]],

         [[ 3.1404e+00,  8.8478e-02, -1.3196e+00,  ..., -2.9387e-01,
           -4.0089e-01, -1.2615e+00],
          [ 8.8478e-02,  4.0128e+00,  2.8375e-01,  ...,  1.4456e+00,
           -1.7300e-01,  1.0208e+00],
          [-1.3196e+00,  2.8375e-01,  1.0115e+01,  ...,  2.4882e+00,
           -1.3056e+00,  1.0078e+00],
          ...,
          [-2.9387e-01,  1.4456e+00,  2.4882e+00,  ...,  9.2826e+00,
           -1.0373e-01,  6.2864e-01],
          [-4.0089e-01, -1.7300e-01, -1.3056e+00,  ..., -1.0373e-01,
            2.7372e+00,  1.6051e+00],
          [-1.2615e+00,  1.0208e+00,  1.0078e+00,  ...,  6.2864e-01,
            1.6051e+00,  5.5887e+00]],

         [[ 1.6410e+00, -1.8705e-01, -1.7080e-01,  ..., -1.7080e-01,
           -7.0236e-01, -8.9657e-01],
          [-1.8705e-01,  2.1557e+00,  1.1993e-01,  ...,  3.5611e-01,
           -8.2990e-01,  5.0009e-01],
          [-1.7080e-01,  1.1993e-01,  4.2247e+00,  ...,  7.4028e-01,
           -9.4505e-01,  7.6899e-01],
          ...,
          [-1.7080e-01,  3.5611e-01,  7.4028e-01,  ...,  4.8397e+00,
            6.9999e-01, -4.0617e-02],
          [-7.0236e-01, -8.2990e-01, -9.4505e-01,  ...,  6.9999e-01,
            3.6925e+00,  8.3948e-01],
          [-8.9657e-01,  5.0009e-01,  7.6899e-01,  ..., -4.0617e-02,
            8.3948e-01,  3.1297e+00]]],


        [[[ 1.0435e-01, -1.5680e-01, -1.3139e-02,  ..., -8.6163e-02,
           -1.1744e-01,  2.4690e-01],
          [-1.5680e-01,  3.0945e+00, -2.0342e-01,  ..., -4.2523e-01,
           -1.8642e+00, -5.8446e-01],
          [-1.3139e-02, -2.0342e-01,  4.9515e-01,  ...,  2.0724e-02,
           -3.9217e-02, -2.7887e-01],
          ...,
          [-8.6163e-02, -4.2523e-01,  2.0724e-02,  ...,  1.4535e+00,
            1.5016e+00, -7.3917e-01],
          [-1.1744e-01, -1.8642e+00, -3.9217e-02,  ...,  1.5016e+00,
            7.9352e+00, -2.8087e+00],
          [ 2.4690e-01, -5.8446e-01, -2.7887e-01,  ..., -7.3917e-01,
           -2.8087e+00,  7.4959e+00]],

         [[ 1.0626e-01, -5.3499e-01, -9.6329e-02,  ..., -4.1398e-01,
            2.1991e-01,  3.2964e-01],
          [-5.3499e-01,  2.9266e+00,  5.3086e-01,  ...,  1.9491e+00,
           -1.0397e+00, -1.5764e+00],
          [-9.6329e-02,  5.3086e-01,  5.6259e-01,  ..., -1.3441e-01,
           -7.6260e-01, -1.8934e-01],
          ...,
          [-4.1398e-01,  1.9491e+00, -1.3441e-01,  ...,  1.2002e+01,
            1.6773e+00, -1.1184e+00],
          [ 2.1991e-01, -1.0397e+00, -7.6260e-01,  ...,  1.6773e+00,
            1.0893e+01,  2.6986e-01],
          [ 3.2964e-01, -1.5764e+00, -1.8934e-01,  ..., -1.1184e+00,
            2.6986e-01,  1.8279e+01]],

         [[ 2.0433e-01, -6.0532e-01, -6.0465e-01,  ..., -6.2567e-01,
            8.8886e-02,  5.7080e-01],
          [-6.0532e-01,  2.8582e+00,  2.3388e+00,  ...,  3.0698e+00,
           -2.6689e-01, -3.6055e-01],
          [-6.0465e-01,  2.3388e+00,  7.2763e+00,  ...,  2.1341e+00,
           -1.0772e+00, -2.3935e+00],
          ...,
          [-6.2567e-01,  3.0698e+00,  2.1341e+00,  ...,  1.2192e+01,
            2.1627e+00,  6.7769e-01],
          [ 8.8886e-02, -2.6689e-01, -1.0772e+00,  ...,  2.1627e+00,
            8.0345e+00,  3.2684e+00],
          [ 5.7080e-01, -3.6055e-01, -2.3935e+00,  ...,  6.7769e-01,
            3.2684e+00,  1.6359e+01]],

         ...,

         [[ 1.4648e-01, -1.1491e-01, -2.7340e-01,  ..., -1.2295e-01,
           -6.9424e-02,  3.6800e-01],
          [-1.1491e-01,  6.7244e+00,  2.1170e-01,  ...,  8.1869e-01,
           -2.1934e+00,  3.3195e-01],
          [-2.7340e-01,  2.1170e-01,  1.6448e+00,  ...,  5.7570e-01,
           -7.5520e-02, -1.4958e+00],
          ...,
          [-1.2295e-01,  8.1869e-01,  5.7570e-01,  ...,  2.1348e+00,
            1.0615e+00, -1.4271e+00],
          [-6.9424e-02, -2.1934e+00, -7.5520e-02,  ...,  1.0615e+00,
            5.6863e+00, -2.4001e+00],
          [ 3.6800e-01,  3.3195e-01, -1.4958e+00,  ..., -1.4271e+00,
           -2.4001e+00,  7.7843e+00]],

         [[ 7.7600e-01, -2.0639e-02, -7.2864e-01,  ..., -1.9157e-01,
           -6.1336e-01,  9.7962e-02],
          [-2.0639e-02,  7.7400e+00,  6.6602e-01,  ...,  2.5624e+00,
           -2.4853e+00,  2.1299e+00],
          [-7.2864e-01,  6.6602e-01,  8.0750e+00,  ...,  2.0477e+00,
            1.0824e-01, -1.2845e+00],
          ...,
          [-1.9157e-01,  2.5624e+00,  2.0477e+00,  ...,  6.5405e+00,
            7.9761e-01,  1.7518e+00],
          [-6.1336e-01, -2.4853e+00,  1.0824e-01,  ...,  7.9761e-01,
            6.5287e+00,  7.4779e-01],
          [ 9.7962e-02,  2.1299e+00, -1.2845e+00,  ...,  1.7518e+00,
            7.4779e-01,  7.5688e+00]],

         [[ 7.1556e-01, -3.6847e-01, -2.9601e-01,  ..., -3.3812e-01,
           -3.0292e-01, -7.9734e-02],
          [-3.6847e-01,  2.1026e+00,  3.2322e-01,  ...,  6.1110e-01,
           -6.4078e-01,  5.9089e-01],
          [-2.9601e-01,  3.2322e-01,  3.1993e+00,  ...,  5.0087e-01,
           -7.0823e-01,  1.5203e-01],
          ...,
          [-3.3812e-01,  6.1110e-01,  5.0087e-01,  ...,  2.8667e+00,
            8.6047e-01,  4.6084e-01],
          [-3.0292e-01, -6.4078e-01, -7.0823e-01,  ...,  8.6047e-01,
            2.9318e+00,  3.8269e-01],
          [-7.9734e-02,  5.9089e-01,  1.5203e-01,  ...,  4.6084e-01,
            3.8269e-01,  2.6491e+00]]],


        [[[ 9.9086e-01, -1.3998e+00, -3.6653e-01,  ..., -1.2236e+00,
           -1.6737e-01, -1.5520e-01],
          [-1.3998e+00,  2.1872e+00,  7.3765e-01,  ...,  2.0230e+00,
            2.3221e-01,  6.8362e-01],
          [-3.6653e-01,  7.3765e-01,  7.2464e+00,  ..., -4.2438e-01,
           -1.7580e+00,  1.9572e+00],
          ...,
          [-1.2236e+00,  2.0230e+00, -4.2438e-01,  ...,  1.4082e+01,
            2.8091e+00,  4.8819e+00],
          [-1.6737e-01,  2.3221e-01, -1.7580e+00,  ...,  2.8091e+00,
            6.3038e+00,  3.6015e+00],
          [-1.5520e-01,  6.8362e-01,  1.9572e+00,  ...,  4.8819e+00,
            3.6015e+00,  1.1810e+01]],

         [[ 4.3665e-01,  2.6829e-01, -2.6143e-01,  ...,  2.0806e-01,
           -2.7577e-01,  7.1315e-02],
          [ 2.6829e-01,  1.4754e+01, -9.8355e-01,  ...,  1.0828e-01,
           -4.0829e+00, -6.9069e-01],
          [-2.6143e-01, -9.8355e-01,  1.4390e+00,  ...,  7.1068e-01,
            1.4734e-01, -5.3480e-01],
          ...,
          [ 2.0806e-01,  1.0828e-01,  7.1068e-01,  ...,  4.7903e+00,
           -1.0500e+00, -1.4510e+00],
          [-2.7577e-01, -4.0829e+00,  1.4734e-01,  ..., -1.0500e+00,
            3.8879e+00, -3.4604e-01],
          [ 7.1315e-02, -6.9069e-01, -5.3480e-01,  ..., -1.4510e+00,
           -3.4604e-01,  6.3191e+00]],

         [[ 1.0422e-01, -4.1225e-01,  1.9011e-01,  ..., -2.3172e-01,
            4.6882e-02,  1.4920e-01],
          [-4.1225e-01,  1.8831e+00, -8.4669e-01,  ...,  3.8836e-01,
           -4.9545e-01, -8.2438e-01],
          [ 1.9011e-01, -8.4669e-01,  5.6561e-01,  ..., -5.5569e-01,
           -1.5508e-01,  5.6516e-01],
          ...,
          [-2.3172e-01,  3.8836e-01, -5.5569e-01,  ...,  7.6742e+00,
            3.5744e+00,  1.9458e+00],
          [ 4.6882e-02, -4.9545e-01, -1.5508e-01,  ...,  3.5744e+00,
            1.2275e+01,  1.3355e+00],
          [ 1.4920e-01, -8.2438e-01,  5.6516e-01,  ...,  1.9458e+00,
            1.3355e+00,  1.3045e+01]],

         ...,

         [[ 1.5192e-01, -2.9293e-01, -2.1621e-01,  ..., -2.5053e-01,
            2.6970e-02,  3.4687e-01],
          [-2.9293e-01,  2.5758e+00,  4.8654e-01,  ...,  6.5381e-01,
           -7.9741e-01, -3.1284e-01],
          [-2.1621e-01,  4.8654e-01,  1.2677e+00,  ...,  3.3996e-01,
           -4.7447e-01, -8.7797e-01],
          ...,
          [-2.5053e-01,  6.5381e-01,  3.3996e-01,  ...,  2.8251e+00,
            1.3353e+00, -1.4267e+00],
          [ 2.6970e-02, -7.9741e-01, -4.7447e-01,  ...,  1.3353e+00,
            5.1867e+00, -1.2575e+00],
          [ 3.4687e-01, -3.1284e-01, -8.7797e-01,  ..., -1.4267e+00,
           -1.2575e+00,  7.4440e+00]],

         [[ 5.7047e-01, -4.1308e-01, -1.9781e-01,  ..., -3.4726e-01,
           -3.0805e-01, -8.6290e-03],
          [-4.1308e-01,  1.8525e+00,  3.0749e-01,  ...,  5.7121e-01,
           -6.1551e-01,  5.0626e-01],
          [-1.9781e-01,  3.0749e-01,  2.7295e+00,  ...,  2.5951e-01,
           -6.8314e-01,  1.3950e-01],
          ...,
          [-3.4726e-01,  5.7121e-01,  2.5951e-01,  ...,  2.6925e+00,
            1.0987e+00,  6.1820e-01],
          [-3.0805e-01, -6.1551e-01, -6.8314e-01,  ...,  1.0987e+00,
            3.5841e+00,  2.8193e-01],
          [-8.6290e-03,  5.0626e-01,  1.3950e-01,  ...,  6.1820e-01,
            2.8193e-01,  3.0224e+00]],

         [[ 6.7950e-02, -2.0289e-02, -3.0765e-01,  ..., -6.0706e-02,
            8.0836e-03,  4.2396e-01],
          [-2.0289e-02,  2.1411e+01, -1.9893e-01,  ...,  2.0437e+00,
           -4.1471e+00,  8.5025e-01],
          [-3.0765e-01, -1.9893e-01,  2.2021e+00,  ...,  7.8810e-01,
            3.1163e-03, -3.1644e+00],
          ...,
          [-6.0706e-02,  2.0437e+00,  7.8810e-01,  ...,  4.7559e+00,
            2.2746e+00, -2.7648e+00],
          [ 8.0836e-03, -4.1471e+00,  3.1163e-03,  ...,  2.2746e+00,
            1.1284e+01, -4.3158e+00],
          [ 4.2396e-01,  8.5025e-01, -3.1644e+00,  ..., -2.7648e+00,
           -4.3158e+00,  2.3829e+01]]]], grad_fn=<ExpandBackward0>)