# Survival kit Torch Distribution

In [2]:
import torch
import gpytorch
import torch.distributions as dist
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions import Normal, Independent

from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.likelihoods import GaussianLikelihood, FixedNoiseGaussianLikelihood
from gpytorch.models import ExactGP
from gpytorch.mlls import ExactMarginalLogLikelihood

Article

https://bochang.me/blog/posts/pytorch-distributions/

Doc Torch

https://docs.pytorch.org/docs/1.3.0/distributions.html#multivariatenormal

**Paramètres de la distribution:**

event_shape : dimension de la variable

batch_shape : batch size, indépendent, non identiques

**Arguments de rsample:**

sample_shape : nombre d'échantillons iid

In [3]:
# Crée une distribution MVN
# - covariance_matrix est de dimension 2x2, donc la dimension de la variable aléatoire est 2 : event_shape = 2
# - loc est le vecteur de moyenne de dimension 2

mvn = MultivariateNormal(
    loc=torch.zeros(2),
    covariance_matrix=torch.tensor([[1.0, 0.5], [0.5, 1.0]])
)

samples = mvn.rsample((1000,))  # Génère 1000 échantillons

print(samples.shape)  # Affiche la forme des échantillons

torch.Size([1000, 2])


In [4]:
# crée une autre distribution MVN
# - covariance_matrix toujours 2x2 => event_shape = 2
# - loc déduit du tensor en input : shape (5, 3, 2) + event_shape = 2 => batch_shape = (5, 3)

mvn = MultivariateNormal(
    loc=torch.randn(5,3,2),
    covariance_matrix=torch.eye(2)
)

samples = mvn.rsample((10,))  # Génère 10 échantillons

print(samples.shape)  # Affiche la forme des échantillons

torch.Size([10, 5, 3, 2])


In [5]:
# crée une autre distribution MVN
# - covariance_matrix toujours 2x2 => event_shape = 2
# - loc déduit du tensor en input : shape (5, 3, 2) + event_shape = 2 => batch_shape = (5, 3)

try:
    mvn = MultivariateNormal(
        loc=torch.randn(5,3,4),
        covariance_matrix=torch.eye(2)
    )

    samples = mvn.rsample((10,))  # Génère 10 échantillons
except RuntimeError as e:
    print(f"Non correspondance entre la taille de la matrice de covariance et la dernière dimension de loc")

Non correspondance entre la taille de la matrice de covariance et la dernière dimension de loc


### Independent class

The Independent class does not represent any probability distribution. Instead, it creates a new distribution instance by “reinterpreting” some of the batch shapes of an existing distribution as event shapes.

torch.distributions.independent.Independent(base_distribution,
reinterpreted_batch_ndims, validate_args=None)

In [6]:

normal = Normal(
    loc = torch.randn(5,3,2),
    scale = torch.ones(5,3,2)
)

print(f"AVANT Independent:")
print(f"Caractéristiques de la distribution normale : batch_shape = {normal.batch_shape}, event_shape = {normal.event_shape}")

# use Independent pour recaster la dernière dimension de 'batch_shape' dans 'event_shape' 
independent_normal_1 = Independent(normal, 1) 
print(f"APRES Independent(normal, 1):")
print(f"Caractéristiques de la distribution normale indépendante : batch_shape = {independent_normal_1.batch_shape}, event_shape = {independent_normal_1.event_shape}")

# use Independent pour recaster les deux dernières dimensions de 'batch_shape' dans 'event_shape' 
independent_normal_2 = Independent(normal, 2) 
print(f"APRES Independent(normal, 2):")
print(f"Caractéristiques de la distribution normale indépendante : batch_shape = {independent_normal_2.batch_shape}, event_shape = {independent_normal_2.event_shape}")

AVANT Independent:
Caractéristiques de la distribution normale : batch_shape = torch.Size([5, 3, 2]), event_shape = torch.Size([])
APRES Independent(normal, 1):
Caractéristiques de la distribution normale indépendante : batch_shape = torch.Size([5, 3]), event_shape = torch.Size([2])
APRES Independent(normal, 2):
Caractéristiques de la distribution normale indépendante : batch_shape = torch.Size([5]), event_shape = torch.Size([3, 2])


# For GP-VAE

In [7]:
T = 50 # sequence length
D_z = 8 # z dimension
B = 16 # batch size

mvn = MultivariateNormal(
    loc=torch.zeros(T, B, D_z),
    covariance_matrix=torch.eye(D_z)
)

print(f"Caractéristiques de la distribution MVN : batch_shape = {mvn.batch_shape}, event_shape = {mvn.event_shape}")

# sampling one sample
print(f"Génère 1 échantillon")
sample = mvn.rsample((1,))  # Génère un échantillon
print(f"Shape de l'échantillon : {sample.shape}")  # Affiche la forme de l'échantillon

# sampling K samples
K = 10
print(f"Génère {K} échantillons")
samples = mvn.rsample((K,))  # Génère K échantillons
print(f"Shape des échantillons : {samples.shape}")  # Affiche la forme des échantillons

Caractéristiques de la distribution MVN : batch_shape = torch.Size([50, 16]), event_shape = torch.Size([8])
Génère 1 échantillon
Shape de l'échantillon : torch.Size([1, 50, 16, 8])
Génère 10 échantillons
Shape des échantillons : torch.Size([10, 50, 16, 8])


### Processus Gaussien avec GPyTorch

Use case :

- y_train : on a échantilloné **un** sample $z_{1:T}$ de dimension $D_z$ : tensor shape $B \times T \times D_z$
- x_train : instants $(t_i)_{i=1,...T}$ : tensor shape $B \times T \times 1$

In [33]:
T = 50 # sequence length
D_z = 8 # z dimension
B = 16 # batch size

print(f"Batch size = {B}, sequence length = {T}, z dimension = {D_z}")

Batch size = 16, sequence length = 50, z dimension = 8


In [68]:
# TARGETS BRUTES
# ---------------

# train_x est un batch de B séquences temporelles de longueur T
train_x = torch.randn((B, T, 1))  # B x T x 1
# la target est un sample d'un batch de B séquences temporelles de longueur T de variables de dimension D_z
train_y = torch.randn((B, T, D_z))   # B x T x D_z

# RESHAPES
# --------

# duplicate train_x along the last dimension to match the output dimension D_z
train_x = train_x.expand(-1, -1, D_z)  # B x T x D_z
# permute les deux dernières dimensions pour avoir la forme (B, D_z, T)
train_x = train_x.permute(0, 2, 1)  # B x D_z x T
train_y = train_y.permute(0, 2, 1)  # B x D_z x T

# batch augmenté en B x D_z
train_x = train_x.reshape(B * D_z, T, 1)  # (B * D_z) x T
train_y = train_y.reshape(B * D_z, T)  # (B * D_z) x T
# batch_shape = (B * D_z)
bs = B * D_z  # B x D_z

# REPORT
# ------

print(f"train_x shape: {train_x.shape}")
print(f"train_y shape: {train_y.shape}")
print(f"batch_shape: {bs}")

train_x shape: torch.Size([128, 50, 1])
train_y shape: torch.Size([128, 50])
batch_shape: 128


In [None]:
# We will use the simplest form of GP model, exact inference

#--------------------------------------------
#   BATCH VERSION
#--------------------------------------------

class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        
        """An __init__ method that takes the training data and a likelihood, 
        and constructs whatever objects are necessary for the model"s forward method. 
        This will most commonly include things like a mean module and a kernel module.
        
        Inputs:
        - train_x: the training data, a tensor of shape (n, d) where n is the number of training points and d is the input dimension.
        - train_y: the training targets, a tensor of shape (n,) or (n, 1) where n is the number of training points.
        - likelihood: a likelihood object, typically a GaussianLikelihood or FixedNoise GaussianLikelihood.
        """
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        # constant mean as prior mean
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([bs]))
        # RBF kernel as prior covariance
        # first, specify a RBF kernel with a SmoothedBoxPrior on the lengthscale (approx Uniform)
        self.rbf = gpytorch.kernels.RBFKernel(batch_shape=torch.Size([bs]), lengthscale_prior=gpytorch.priors.SmoothedBoxPrior(1e-3, 1e+3, sigma=1e-2))
        # then, scale the RBF kernel with a SmoothedBoxPrior on the outputscale (approx Uniform)
        self.covar_module = gpytorch.kernels.ScaleKernel(self.rbf, batch_shape=torch.Size([bs]), outputscale_prior=gpytorch.priors.SmoothedBoxPrior(1e-3, 1e+3, sigma=1e-2))

    def forward(self, x):
        """A forward method that takes in some n x d data x and returns a MultivariateNormal with the prior mean and covariance evaluated at x. 
        In other words, we return the vector mu(x) and the n x n matrix representing the prior mean and covariance matrix of the GP.
        """
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
    
    
    
likelihood = GaussianLikelihood()

# # Instantiate the model
model = ExactGPModel(train_x, train_y, likelihood)

In [79]:
sampled_z = torch.randn((B, T, D_z))  # Input data for prediction
sampled_z = sampled_z.permute(0, 2, 1)  # Reshape to (B, D_z, T)
sampled_z = sampled_z.reshape(B * D_z, T, 1)  # Reshape to (B * D_z, T, 1)

print(f"sampled_z shape: {sampled_z.shape}")

model.eval()  # Set the model to evaluation mode
likelihood.eval()  # Set the likelihood to evaluation mode

gp_output = model(sampled_z)  # Forward pass through the model

sampled_z shape: torch.Size([128, 50, 1])


In [92]:
print(f"Caractéristiques de la sortie du modèle GP :")
print(gp_output)  # Print the output of the GP model
print(f"Mean shape: {gp_output.mean.shape}, Covariance shape: {gp_output.covariance_matrix.shape}")
print(f"Batch shape: {gp_output.batch_shape}, Event shape: {gp_output.event_shape}")

Caractéristiques de la sortie du modèle GP :
MultivariateNormal(loc: torch.Size([128, 50]), covariance_matrix: torch.Size([128, 50, 50]))
Mean shape: torch.Size([128, 50]), Covariance shape: torch.Size([128, 50, 50])
Batch shape: torch.Size([128]), Event shape: torch.Size([50])


In [96]:
mean_r = gp_output.mean.view(B, D_z, T)  # Reshape the mean to (B, D_z, T)
covar_r = gp_output.covariance_matrix.view(B, D_z, T, T)  # Reshape the covariance matrix to (B, D_z, T, T)

print(f"Mean reshaped: {mean_r.shape}, Covariance reshaped: {covar_r.shape}")

mnv_r = MultivariateNormal(
    loc=mean_r,
    covariance_matrix=covar_r
)

print(f"Caractéristiques de la distribution MVN résultante :")
print(f"Batch shape: {mvn_r.batch_shape}, Event shape: {mvn_r.event_shape}")

Mean reshaped: torch.Size([16, 8, 50]), Covariance reshaped: torch.Size([16, 8, 50, 50])


ValueError: Expected parameter covariance_matrix (Tensor of shape (16, 8, 50, 50)) of distribution MultivariateNormal(loc: torch.Size([16, 8, 50]), covariance_matrix: torch.Size([16, 8, 50, 50])) to satisfy the constraint PositiveDefinite(), but found invalid values:
tensor([[[[ 1.2246e-01,  1.3547e-01, -5.6204e-04,  ..., -1.0692e-02,
            2.2334e-03,  6.1861e-03],
          [ 1.3547e-01,  1.5537e-01, -1.8775e-04,  ..., -1.0547e-02,
            2.7847e-03,  6.7129e-03],
          [-5.6204e-04, -1.8775e-04,  1.6825e-01,  ...,  2.2143e-03,
            2.0643e-01, -1.2879e-02],
          ...,
          [-1.0692e-02, -1.0547e-02,  2.2143e-03,  ...,  3.7139e-02,
            4.9091e-03,  7.0602e-03],
          [ 2.2334e-03,  2.7847e-03,  2.0643e-01,  ...,  4.9091e-03,
            4.6627e-01, -1.6280e-02],
          [ 6.1860e-03,  6.7129e-03, -1.2879e-02,  ...,  7.0602e-03,
           -1.6280e-02,  5.7898e-02]],

         [[ 4.3903e-02, -5.2263e-03,  8.3811e-03,  ..., -1.7140e-03,
            4.1678e-03, -4.8693e-03],
          [-5.2263e-03,  5.0108e-02,  1.3180e-02,  ...,  1.6240e-03,
            4.9014e-02,  3.7199e-02],
          [ 8.3811e-03,  1.3181e-02,  3.9436e-02,  ..., -1.2515e-03,
           -1.7125e-02,  2.9260e-02],
          ...,
          [-1.7138e-03,  1.6240e-03, -1.2514e-03,  ...,  7.8040e-02,
           -2.3914e-03,  2.2829e-03],
          [ 4.1678e-03,  4.9014e-02, -1.7125e-02,  ..., -2.3914e-03,
            1.5370e-01,  7.9709e-03],
          [-4.8692e-03,  3.7199e-02,  2.9260e-02,  ...,  2.2829e-03,
            7.9709e-03,  4.0783e-02]],

         [[ 5.3541e-02,  7.5480e-03, -6.9608e-03,  ..., -7.8890e-03,
            6.8551e-03,  1.0808e-02],
          [ 7.5481e-03,  2.0158e-01, -1.8868e-02,  ..., -8.5916e-03,
            1.5234e-01, -1.0136e-02],
          [-6.9607e-03, -1.8868e-02,  4.0163e-02,  ...,  3.7386e-02,
           -1.0238e-02,  2.6794e-02],
          ...,
          [-7.8889e-03, -8.5916e-03,  3.7386e-02,  ...,  4.0416e-02,
            4.8429e-03,  1.6972e-02],
          [ 6.8551e-03,  1.5234e-01, -1.0238e-02,  ...,  4.8429e-03,
            1.3365e-01, -1.1925e-02],
          [ 1.0808e-02, -1.0136e-02,  2.6794e-02,  ...,  1.6972e-02,
           -1.1925e-02,  3.7177e-02]],

         ...,

         [[ 3.8732e-02,  3.5558e-02,  1.9447e-02,  ...,  3.8706e-02,
            7.5677e-03,  3.8965e-02],
          [ 3.5557e-02,  3.7167e-02,  7.8052e-03,  ...,  3.5617e-02,
            1.7047e-02,  3.3604e-02],
          [ 1.9447e-02,  7.8053e-03,  5.8393e-02,  ...,  1.9223e-02,
           -7.9518e-03,  2.5534e-02],
          ...,
          [ 3.8706e-02,  3.5617e-02,  1.9223e-02,  ...,  3.8683e-02,
            7.7267e-03,  3.8898e-02],
          [ 7.5676e-03,  1.7047e-02, -7.9518e-03,  ...,  7.7267e-03,
            4.0321e-02,  3.5469e-03],
          [ 3.8965e-02,  3.3604e-02,  2.5534e-02,  ...,  3.8898e-02,
            3.5469e-03,  4.0301e-02]],

         [[ 8.7728e-02,  2.5207e-03, -3.8425e-03,  ...,  2.5477e-03,
           -1.2340e-03, -7.2367e-04],
          [ 2.5207e-03,  4.0311e-02,  1.9010e-02,  ...,  3.9893e-02,
            1.0466e-02,  3.0007e-02],
          [-3.8425e-03,  1.9010e-02,  3.7142e-02,  ...,  1.5152e-02,
           -8.7951e-03,  3.4557e-02],
          ...,
          [ 2.5477e-03,  3.9893e-02,  1.5152e-02,  ...,  4.0384e-02,
            1.7058e-02,  2.6407e-02],
          [-1.2340e-03,  1.0465e-02, -8.7951e-03,  ...,  1.7058e-02,
            8.5256e-02, -8.2067e-03],
          [-7.2366e-04,  3.0007e-02,  3.4557e-02,  ...,  2.6407e-02,
           -8.2067e-03,  3.8183e-02]],

         [[ 1.3563e-01, -5.8711e-03, -9.8941e-04,  ...,  3.4617e-04,
           -1.1554e-03, -2.8176e-03],
          [-5.8710e-03,  4.2726e-02,  3.5075e-02,  ..., -5.1057e-03,
            1.7242e-03,  3.9029e-02],
          [-9.8941e-04,  3.5075e-02,  3.7747e-02,  ...,  1.6570e-03,
           -6.7142e-03,  3.7622e-02],
          ...,
          [ 3.4617e-04, -5.1056e-03,  1.6570e-03,  ...,  4.2200e-02,
            2.5146e-02, -1.6888e-03],
          [-1.1554e-03,  1.7243e-03, -6.7141e-03,  ...,  2.5146e-02,
            1.3184e-01, -3.1389e-03],
          [-2.8175e-03,  3.9029e-02,  3.7623e-02,  ..., -1.6888e-03,
           -3.1389e-03,  3.9170e-02]]],


        [[[ 2.0510e-01, -1.0087e-02,  7.3707e-04,  ...,  2.8749e-03,
            1.9862e-01,  2.9538e-03],
          [-1.0087e-02,  4.6160e-02, -5.9874e-03,  ..., -8.5872e-04,
           -1.1937e-02,  5.2924e-03],
          [ 7.3707e-04, -5.9874e-03,  4.4406e-02,  ...,  3.2156e-02,
            1.6577e-03,  2.3382e-02],
          ...,
          [ 2.8749e-03, -8.5872e-04,  3.2156e-02,  ...,  3.8412e-02,
            3.7091e-03,  3.6592e-02],
          [ 1.9862e-01, -1.1937e-02,  1.6577e-03,  ...,  3.7091e-03,
            2.0519e-01,  3.1140e-03],
          [ 2.9538e-03,  5.2925e-03,  2.3382e-02,  ...,  3.6592e-02,
            3.1140e-03,  3.8679e-02]],

         [[ 5.5470e-02,  4.7662e-02, -4.9698e-03,  ...,  5.7940e-02,
           -6.9915e-03, -6.5890e-03],
          [ 4.7662e-02,  8.1808e-02, -8.8931e-03,  ...,  6.8577e-02,
           -6.3999e-03, -1.0230e-03],
          [-4.9697e-03, -8.8931e-03,  4.0781e-02,  ..., -9.5323e-03,
            3.9956e-02,  3.2565e-02],
          ...,
          [ 5.7940e-02,  6.8578e-02, -9.5322e-03,  ...,  6.9949e-02,
           -8.9076e-03, -4.7678e-03],
          [-6.9914e-03, -6.3999e-03,  3.9956e-02,  ..., -8.9075e-03,
            4.1612e-02,  3.8358e-02],
          [-6.5890e-03, -1.0230e-03,  3.2565e-02,  ..., -4.7678e-03,
            3.8358e-02,  4.3992e-02]],

         [[ 5.3481e-02,  1.5873e-03,  5.6119e-02,  ..., -6.2818e-03,
            1.1292e-03,  1.8497e-02],
          [ 1.5872e-03,  1.0184e-01,  1.0133e-02,  ...,  1.5543e-03,
           -1.8478e-04, -3.5347e-03],
          [ 5.6119e-02,  1.0133e-02,  6.4210e-02,  ..., -8.6683e-03,
           -7.0410e-04,  9.6766e-03],
          ...,
          [-6.2818e-03,  1.5543e-03, -8.6683e-03,  ...,  4.1578e-02,
           -1.1714e-02,  2.0186e-02],
          [ 1.1292e-03, -1.8478e-04, -7.0410e-04,  ..., -1.1714e-02,
            1.9402e-01,  3.0609e-03],
          [ 1.8497e-02, -3.5347e-03,  9.6765e-03,  ...,  2.0186e-02,
            3.0609e-03,  3.8419e-02]],

         ...,

         [[ 6.9100e-02, -8.6930e-03,  3.1382e-02,  ..., -9.7418e-04,
           -6.0409e-03,  2.7914e-03],
          [-8.6930e-03,  4.0286e-02,  1.3143e-02,  ...,  3.5330e-02,
            3.8896e-02, -1.0478e-02],
          [ 3.1382e-02,  1.3143e-02,  4.1285e-02,  ...,  2.6167e-02,
            1.9509e-02,  1.9247e-03],
          ...,
          [-9.7424e-04,  3.5330e-02,  2.6167e-02,  ...,  3.8882e-02,
            3.8087e-02, -6.7822e-03],
          [-6.0409e-03,  3.8896e-02,  1.9509e-02,  ...,  3.8087e-02,
            3.9592e-02, -9.2040e-03],
          [ 2.7914e-03, -1.0478e-02,  1.9246e-03,  ..., -6.7822e-03,
           -9.2040e-03,  1.1059e-01]],

         [[ 3.8453e-02, -7.1426e-03,  2.6952e-02,  ...,  2.1675e-03,
           -5.6269e-03,  2.7991e-02],
          [-7.1426e-03,  7.1407e-02, -5.0028e-04,  ...,  5.3486e-02,
            5.7866e-02, -1.7845e-03],
          [ 2.6952e-02, -5.0026e-04,  4.4375e-02,  ...,  3.5167e-03,
           -3.5785e-03,  4.5144e-03],
          ...,
          [ 2.1676e-03,  5.3486e-02,  3.5167e-03,  ...,  1.8501e-01,
            1.8557e-02, -9.1897e-03],
          [-5.6269e-03,  5.7866e-02, -3.5785e-03,  ...,  1.8557e-02,
            5.5936e-02,  8.8015e-03],
          [ 2.7991e-02, -1.7844e-03,  4.5145e-03,  ..., -9.1897e-03,
            8.8015e-03,  4.0685e-02]],

         [[ 6.7092e-02, -3.5457e-04,  2.7210e-03,  ...,  7.0470e-02,
            3.7646e-02,  5.8789e-02],
          [-3.5451e-04,  4.6443e-02,  4.7867e-02,  ...,  1.0053e-03,
           -5.8572e-03, -2.6399e-03],
          [ 2.7210e-03,  4.7867e-02,  5.9292e-02,  ...,  3.4174e-03,
           -3.8387e-03,  9.1463e-04],
          ...,
          [ 7.0470e-02,  1.0052e-03,  3.4174e-03,  ...,  7.6717e-02,
            3.2287e-02,  5.8172e-02],
          [ 3.7646e-02, -5.8572e-03, -3.8387e-03,  ...,  3.2287e-02,
            4.6615e-02,  4.3827e-02],
          [ 5.8789e-02, -2.6399e-03,  9.1466e-04,  ...,  5.8172e-02,
            4.3827e-02,  5.6518e-02]]],


        [[[ 4.7704e-02,  1.9527e-02, -1.0921e-03,  ...,  3.4585e-02,
            2.0813e-02, -3.3917e-03],
          [ 1.9527e-02,  3.7593e-02, -3.8447e-03,  ..., -1.9257e-03,
           -5.6953e-03,  2.8096e-03],
          [-1.0922e-03, -3.8447e-03,  6.7277e-02,  ...,  2.8856e-02,
            4.6415e-02,  2.2698e-02],
          ...,
          [ 3.4585e-02, -1.9257e-03,  2.8856e-02,  ...,  5.7276e-02,
            5.3737e-02, -1.3221e-02],
          [ 2.0813e-02, -5.6953e-03,  4.6415e-02,  ...,  5.3737e-02,
            5.9087e-02, -9.4508e-03],
          [-3.3917e-03,  2.8096e-03,  2.2698e-02,  ..., -1.3221e-02,
           -9.4509e-03,  1.3608e-01]],

         [[ 1.4566e-01, -9.8583e-04,  1.0714e-03,  ..., -1.1830e-02,
            2.2799e-04,  2.0530e-03],
          [-9.8583e-04,  8.8150e-02,  8.7178e-03,  ...,  3.3907e-03,
            9.8385e-02, -2.3380e-03],
          [ 1.0714e-03,  8.7180e-03,  3.7295e-02,  ..., -6.3031e-03,
           -9.9213e-03,  3.5124e-02],
          ...,
          [-1.1830e-02,  3.3907e-03, -6.3031e-03,  ...,  5.8137e-02,
            4.2809e-03, -4.1362e-03],
          [ 2.2799e-04,  9.8385e-02, -9.9213e-03,  ...,  4.2809e-03,
            1.8051e-01, -1.2985e-02],
          [ 2.0530e-03, -2.3380e-03,  3.5125e-02,  ..., -4.1361e-03,
           -1.2985e-02,  3.7539e-02]],

         [[ 2.8302e-01,  2.6982e-01, -7.3539e-04,  ...,  2.6753e-03,
           -1.7292e-03,  4.8452e-03],
          [ 2.6982e-01,  2.8262e-01,  1.9421e-03,  ..., -5.1718e-03,
           -1.2654e-03,  4.8468e-03],
          [-7.3539e-04,  1.9421e-03,  4.0326e-02,  ...,  2.2300e-03,
           -5.9182e-03,  3.7337e-02],
          ...,
          [ 2.6754e-03, -5.1717e-03,  2.2300e-03,  ...,  5.1946e-02,
            2.3076e-03, -6.3845e-03],
          [-1.7292e-03, -1.2654e-03, -5.9182e-03,  ...,  2.3076e-03,
            6.7828e-02, -4.3074e-03],
          [ 4.8452e-03,  4.8468e-03,  3.7337e-02,  ..., -6.3845e-03,
           -4.3074e-03,  4.5347e-02]],

         ...,

         [[ 2.5071e-01,  1.4107e-02, -1.2660e-02,  ...,  2.0181e-03,
            2.9075e-03,  5.7274e-03],
          [ 1.4107e-02,  5.2034e-02,  1.9294e-02,  ..., -1.4354e-03,
           -2.5633e-03, -7.3799e-03],
          [-1.2660e-02,  1.9295e-02,  3.7224e-02,  ..., -5.4933e-03,
           -4.5041e-03,  6.2667e-03],
          ...,
          [ 2.0181e-03, -1.4354e-03, -5.4932e-03,  ...,  5.8300e-02,
            5.7492e-02,  4.1547e-02],
          [ 2.9075e-03, -2.5633e-03, -4.5040e-03,  ...,  5.7492e-02,
            5.7663e-02,  4.5333e-02],
          [ 5.7274e-03, -7.3799e-03,  6.2665e-03,  ...,  4.1547e-02,
            4.5333e-02,  5.1940e-02]],

         [[ 6.1903e-02,  3.2676e-02,  1.5708e-03,  ...,  2.8630e-03,
            3.4088e-02, -1.5306e-03],
          [ 3.2676e-02,  5.5653e-02,  4.5679e-03,  ..., -2.6547e-03,
           -9.7077e-03,  9.3156e-04],
          [ 1.5708e-03,  4.5679e-03,  1.4196e-01,  ...,  6.6510e-02,
           -1.9973e-03,  8.9521e-02],
          ...,
          [ 2.8630e-03, -2.6547e-03,  6.6510e-02,  ...,  6.4042e-02,
           -8.0032e-04,  3.6737e-03],
          [ 3.4087e-02, -9.7077e-03, -1.9973e-03,  ..., -8.0032e-04,
            1.0537e-01,  5.7995e-04],
          [-1.5306e-03,  9.3156e-04,  8.9521e-02,  ...,  3.6737e-03,
            5.7995e-04,  2.7525e-01]],

         [[ 1.3769e-01,  1.4922e-01, -1.8591e-03,  ..., -1.1034e-03,
           -3.5165e-03,  2.8663e-03],
          [ 1.4922e-01,  1.6726e-01, -1.4776e-03,  ..., -1.5248e-03,
           -8.0332e-03,  4.3722e-03],
          [-1.8591e-03, -1.4776e-03,  1.1476e-01,  ...,  6.8247e-02,
            2.8611e-03, -1.2489e-02],
          ...,
          [-1.1034e-03, -1.5248e-03,  6.8247e-02,  ...,  7.5415e-02,
            4.9692e-04,  1.4356e-03],
          [-3.5165e-03, -8.0333e-03,  2.8611e-03,  ...,  4.9692e-04,
            3.7351e-02, -1.1735e-03],
          [ 2.8662e-03,  4.3722e-03, -1.2489e-02,  ...,  1.4356e-03,
           -1.1735e-03,  5.3157e-02]]],


        ...,


        [[[ 5.9421e-02,  1.6581e-03,  6.0225e-02,  ..., -9.2409e-03,
            4.8738e-03, -9.7021e-03],
          [ 1.6581e-03,  8.0389e-02,  5.6300e-04,  ..., -8.6126e-03,
           -2.7495e-03, -8.9058e-03],
          [ 6.0225e-02,  5.6300e-04,  6.8143e-02,  ..., -1.0957e-02,
           -2.6593e-03, -1.1023e-02],
          ...,
          [-9.2408e-03, -8.6126e-03, -1.0957e-02,  ...,  5.2312e-02,
            3.8694e-02,  5.2609e-02],
          [ 4.8738e-03, -2.7495e-03, -2.6593e-03,  ...,  3.8694e-02,
            4.2302e-02,  3.7911e-02],
          [-9.7021e-03, -8.9058e-03, -1.1023e-02,  ...,  5.2609e-02,
            3.7911e-02,  5.2991e-02]],

         [[ 3.9393e-01,  2.2062e-04,  1.5159e-01,  ...,  4.1347e-01,
            1.7569e-05, -6.8451e-05],
          [ 2.2062e-04,  7.1479e-02,  7.2592e-05,  ...,  1.2855e-04,
            2.5229e-02,  1.9669e-02],
          [ 1.5159e-01,  7.2593e-05,  2.5124e-01,  ...,  1.1175e-01,
           -1.2823e-03, -1.1744e-03],
          ...,
          [ 4.1347e-01,  1.2855e-04,  1.1175e-01,  ...,  4.5561e-01,
            1.6977e-04,  1.0276e-04],
          [ 1.7570e-05,  2.5229e-02, -1.2823e-03,  ...,  1.6977e-04,
            4.3450e-02,  4.2231e-02],
          [-6.8450e-05,  1.9669e-02, -1.1744e-03,  ...,  1.0277e-04,
            4.2231e-02,  4.1794e-02]],

         [[ 4.1198e-02,  4.0542e-02, -5.5254e-03,  ...,  2.4459e-02,
           -8.1799e-03, -1.5319e-03],
          [ 4.0542e-02,  4.0392e-02, -4.9437e-03,  ...,  2.7666e-02,
           -7.8222e-03, -4.1005e-04],
          [-5.5254e-03, -4.9437e-03,  5.8272e-02,  ...,  1.0571e-02,
            1.2269e-03, -1.0966e-03],
          ...,
          [ 2.4459e-02,  2.7666e-02,  1.0571e-02,  ...,  4.5997e-02,
           -3.7923e-06,  3.4627e-03],
          [-8.1798e-03, -7.8222e-03,  1.2269e-03,  ..., -3.7923e-06,
            1.2368e-01,  9.6138e-02],
          [-1.5319e-03, -4.1006e-04, -1.0966e-03,  ...,  3.4627e-03,
            9.6138e-02,  2.3274e-01]],

         ...,

         [[ 1.6497e-01, -1.0975e-02,  9.0244e-04,  ...,  1.7571e-03,
           -1.4008e-03,  2.0611e-02],
          [-1.0975e-02,  5.8261e-02, -3.5737e-03,  ...,  3.5700e-02,
           -1.8892e-03,  2.0605e-02],
          [ 9.0243e-04, -3.5737e-03,  3.9832e-02,  ...,  1.4757e-02,
            1.7908e-02, -7.1197e-04],
          ...,
          [ 1.7571e-03,  3.5700e-02,  1.4757e-02,  ...,  5.6389e-02,
           -1.1194e-02, -7.9753e-03],
          [-1.4008e-03, -1.8892e-03,  1.7908e-02,  ..., -1.1194e-02,
            6.2241e-02,  2.4000e-03],
          [ 2.0612e-02,  2.0605e-02, -7.1199e-04,  ..., -7.9753e-03,
            2.4000e-03,  7.3486e-02]],

         [[ 5.3396e-02, -9.9451e-03,  5.4543e-02,  ..., -2.2478e-03,
            3.1685e-03,  1.9532e-02],
          [-9.9450e-03,  5.8372e-02, -1.0106e-02,  ...,  4.1144e-02,
            2.2140e-03,  1.4155e-02],
          [ 5.4543e-02, -1.0105e-02,  5.6659e-02,  ..., -5.0847e-03,
            3.2081e-03,  1.5721e-02],
          ...,
          [-2.2475e-03,  4.1144e-02, -5.0848e-03,  ...,  4.9683e-02,
           -8.2029e-03,  3.4452e-02],
          [ 3.1685e-03,  2.2141e-03,  3.2080e-03,  ..., -8.2028e-03,
            6.9885e-02, -3.8199e-03],
          [ 1.9532e-02,  1.4155e-02,  1.5721e-02,  ...,  3.4452e-02,
           -3.8199e-03,  4.0315e-02]],

         [[ 2.0610e-01, -1.2810e-04,  8.8189e-04,  ...,  1.1971e-03,
            1.9268e-01,  1.0588e-01],
          [-1.2810e-04,  7.0992e-02,  5.7632e-04,  ...,  7.9548e-04,
           -3.8124e-04, -1.2715e-03],
          [ 8.8189e-04,  5.7644e-04,  5.8386e-02,  ..., -3.4916e-03,
            1.6682e-03,  3.9304e-03],
          ...,
          [ 1.1969e-03,  7.9549e-04, -3.4916e-03,  ...,  7.6947e-02,
            9.1501e-03,  5.0463e-02],
          [ 1.9268e-01, -3.8124e-04,  1.6682e-03,  ...,  9.1501e-03,
            1.8497e-01,  1.1556e-01],
          [ 1.0588e-01, -1.2715e-03,  3.9304e-03,  ...,  5.0463e-02,
            1.1556e-01,  1.2224e-01]]],


        [[[ 6.1973e-02,  8.1589e-03,  3.4157e-02,  ...,  1.7163e-03,
            7.4451e-03, -9.5267e-04],
          [ 8.1589e-03,  5.0460e-02, -4.2774e-03,  ...,  1.6156e-04,
            5.0283e-02,  2.5469e-03],
          [ 3.4156e-02, -4.2775e-03,  4.3208e-02,  ..., -8.6064e-03,
           -4.3119e-03, -8.0910e-05],
          ...,
          [ 1.7164e-03,  1.6156e-04, -8.6064e-03,  ...,  2.9064e-01,
            9.1983e-05, -1.0412e-04],
          [ 7.4450e-03,  5.0283e-02, -4.3119e-03,  ...,  9.1985e-05,
            5.0135e-02,  2.5786e-03],
          [-9.5267e-04,  2.5469e-03, -8.0910e-05,  ..., -1.0412e-04,
            2.5786e-03,  9.7995e-01]],

         [[ 6.7878e-02,  6.7147e-02,  5.3423e-02,  ..., -6.6037e-04,
            3.0028e-02,  5.9069e-03],
          [ 6.7147e-02,  6.7436e-02,  5.7009e-02,  ..., -8.6678e-04,
            2.5093e-02,  9.5530e-03],
          [ 5.3423e-02,  5.7010e-02,  6.0562e-02,  ..., -9.6241e-04,
            8.0493e-03,  2.5771e-02],
          ...,
          [-6.6037e-04, -8.6678e-04, -9.6241e-04,  ...,  9.8407e-01,
            2.8142e-03,  7.9932e-05],
          [ 3.0028e-02,  2.5093e-02,  8.0492e-03,  ...,  2.8142e-03,
            4.9518e-02, -4.7436e-03],
          [ 5.9068e-03,  9.5530e-03,  2.5771e-02,  ...,  7.9932e-05,
           -4.7436e-03,  4.7115e-02]],

         [[ 4.8640e-02,  4.9163e-04,  2.9780e-04,  ...,  1.1768e-02,
            4.9429e-02, -5.1792e-04],
          [ 4.9164e-04,  6.8149e-02,  1.3854e-02,  ...,  5.7187e-04,
            3.4281e-04,  7.2365e-03],
          [ 2.9778e-04,  1.3854e-02,  4.8063e-02,  ..., -3.9784e-03,
            5.5370e-04,  4.8235e-02],
          ...,
          [ 1.1768e-02,  5.7184e-04, -3.9783e-03,  ...,  5.0074e-02,
            9.9938e-03, -2.6288e-03],
          [ 4.9429e-02,  3.4280e-04,  5.5371e-04,  ...,  9.9939e-03,
            5.0375e-02, -2.1750e-04],
          [-5.1792e-04,  7.2365e-03,  4.8235e-02,  ..., -2.6287e-03,
           -2.1752e-04,  4.9897e-02]],

         ...,

         [[ 5.4585e-02, -5.1755e-03, -1.0463e-03,  ...,  1.2076e-03,
           -1.8401e-04,  1.6032e-03],
          [-5.1755e-03,  6.6444e-02,  4.6162e-03,  ..., -7.5896e-03,
            8.1924e-03,  6.1627e-02],
          [-1.0463e-03,  4.6162e-03,  2.4773e-01,  ...,  1.7105e-02,
           -1.1299e-02,  3.8292e-03],
          ...,
          [ 1.2076e-03, -7.5897e-03,  1.7105e-02,  ...,  5.4244e-02,
            1.8883e-02, -4.7730e-03],
          [-1.8399e-04,  8.1925e-03, -1.1299e-02,  ...,  1.8883e-02,
            4.3183e-02,  1.4270e-03],
          [ 1.6034e-03,  6.1627e-02,  3.8292e-03,  ..., -4.7730e-03,
            1.4270e-03,  6.2028e-02]],

         [[ 4.0147e-02, -8.8278e-03,  1.7977e-03,  ..., -3.9404e-03,
            4.7157e-04,  3.5606e-02],
          [-8.8277e-03,  1.9934e-01, -1.3471e-03,  ...,  1.9746e-03,
           -1.2542e-03, -1.4240e-02],
          [ 1.7977e-03, -1.3471e-03,  6.2808e-02,  ...,  1.3420e-02,
            5.0706e-02, -2.2878e-04],
          ...,
          [-3.9404e-03,  1.9746e-03,  1.3420e-02,  ...,  4.9921e-02,
            2.9776e-02, -3.0164e-03],
          [ 4.7158e-04, -1.2542e-03,  5.0706e-02,  ...,  2.9776e-02,
            5.1519e-02, -3.1342e-03],
          [ 3.5606e-02, -1.4240e-02, -2.2878e-04,  ..., -3.0164e-03,
           -3.1342e-03,  4.9039e-02]],

         [[ 6.4415e-02,  5.1045e-02,  2.1636e-02,  ...,  5.6377e-02,
           -5.8883e-03, -1.7888e-03],
          [ 5.1045e-02,  5.4129e-02,  3.9820e-03,  ...,  5.4954e-02,
           -9.5260e-03, -1.3057e-03],
          [ 2.1636e-02,  3.9819e-03,  5.3308e-02,  ...,  8.0321e-03,
            2.3812e-03,  4.9363e-03],
          ...,
          [ 5.6377e-02,  5.4954e-02,  8.0321e-03,  ...,  5.7126e-02,
           -9.1327e-03, -1.5622e-03],
          [-5.8883e-03, -9.5260e-03,  2.3812e-03,  ..., -9.1327e-03,
            9.5133e-02, -9.7311e-05],
          [-1.7888e-03, -1.3057e-03,  4.9363e-03,  ..., -1.5622e-03,
           -9.7311e-05,  9.3298e-01]]],


        [[[ 5.0698e-02,  1.3989e-03,  1.6138e-03,  ..., -1.1362e-03,
           -6.2933e-03, -5.9010e-03],
          [ 1.3989e-03,  7.0238e-02, -5.1442e-04,  ...,  7.0505e-02,
            1.6960e-03,  1.5981e-03],
          [ 1.6138e-03, -5.1443e-04,  7.5609e-02,  ..., -6.3997e-04,
            1.1834e-02,  8.2032e-03],
          ...,
          [-1.1363e-03,  7.0505e-02, -6.3997e-04,  ...,  7.1901e-02,
            1.7205e-03,  1.7380e-03],
          [-6.2934e-03,  1.6960e-03,  1.1834e-02,  ...,  1.7205e-03,
            5.3345e-02,  5.2950e-02],
          [-5.9010e-03,  1.5981e-03,  8.2033e-03,  ...,  1.7380e-03,
            5.2950e-02,  5.3046e-02]],

         [[ 5.2606e-02,  4.2860e-02,  4.6489e-03,  ..., -6.7611e-03,
            5.3684e-02,  3.9976e-02],
          [ 4.2860e-02,  4.6144e-02,  2.1818e-02,  ..., -1.3292e-05,
            3.9485e-02,  4.5701e-02],
          [ 4.6488e-03,  2.1818e-02,  5.1096e-02,  ...,  3.7039e-02,
           -5.5909e-04,  2.5455e-02],
          ...,
          [-6.7610e-03, -1.3351e-05,  3.7039e-02,  ...,  5.3065e-02,
           -7.5928e-03,  2.1577e-03],
          [ 5.3685e-02,  3.9485e-02, -5.5903e-04,  ..., -7.5927e-03,
            5.6580e-02,  3.5900e-02],
          [ 3.9976e-02,  4.5701e-02,  2.5455e-02,  ...,  2.1576e-03,
            3.5900e-02,  4.5778e-02]],

         [[ 9.6489e-02,  9.0478e-02, -2.8430e-04,  ..., -8.1079e-03,
            3.3238e-03,  1.0301e-01],
          [ 9.0478e-02,  8.7239e-02, -4.3448e-04,  ..., -7.3521e-03,
            2.5341e-03,  9.1497e-02],
          [-2.8430e-04, -4.3448e-04,  8.4674e-02,  ...,  2.3881e-03,
           -8.7128e-03,  2.1667e-05],
          ...,
          [-8.1080e-03, -7.3522e-03,  2.3881e-03,  ...,  5.0632e-02,
            1.1367e-02, -7.7680e-03],
          [ 3.3238e-03,  2.5341e-03, -8.7128e-03,  ...,  1.1367e-02,
            5.0540e-02,  3.9543e-03],
          [ 1.0301e-01,  9.1497e-02,  2.1668e-05,  ..., -7.7680e-03,
            3.9543e-03,  1.2156e-01]],

         ...,

         [[ 1.2528e-01, -1.6069e-03, -7.3090e-05,  ..., -1.4449e-03,
           -7.9463e-03,  9.5200e-04],
          [-1.6069e-03,  7.1068e-02,  3.7241e-02,  ...,  6.7512e-02,
            1.6517e-03,  3.6533e-02],
          [-7.3090e-05,  3.7241e-02,  7.2595e-02,  ...,  2.3583e-02,
            1.0771e-03, -1.8406e-04],
          ...,
          [-1.4449e-03,  6.7512e-02,  2.3583e-02,  ...,  6.8988e-02,
            1.1563e-04,  4.6189e-02],
          [-7.9464e-03,  1.6517e-03,  1.0771e-03,  ...,  1.1563e-04,
            5.5495e-02, -5.2884e-03],
          [ 9.5199e-04,  3.6533e-02, -1.8409e-04,  ...,  4.6189e-02,
           -5.2884e-03,  5.2793e-02]],

         [[ 6.9853e-02,  6.8248e-02,  1.0484e-03,  ..., -4.7664e-03,
            4.7711e-02,  5.5893e-02],
          [ 6.8248e-02,  6.8967e-02,  1.7480e-04,  ..., -3.7533e-03,
            3.5442e-02,  6.2084e-02],
          [ 1.0483e-03,  1.7481e-04,  5.2603e-02,  ...,  7.4525e-03,
            1.4454e-03, -3.4731e-03],
          ...,
          [-4.7665e-03, -3.7532e-03,  7.4526e-03,  ...,  4.5739e-02,
           -1.4386e-03,  3.0388e-03],
          [ 4.7711e-02,  3.5442e-02,  1.4454e-03,  ..., -1.4386e-03,
            1.1359e-01,  6.7297e-03],
          [ 5.5893e-02,  6.2084e-02, -3.4730e-03,  ...,  3.0388e-03,
            6.7298e-03,  7.0613e-02]],

         [[ 6.9162e-02, -5.0208e-03, -1.0198e-02,  ..., -1.0686e-02,
            3.5048e-02,  6.4166e-02],
          [-5.0208e-03,  5.3178e-02,  4.5933e-02,  ...,  4.2775e-02,
            1.5740e-03, -1.4487e-03],
          [-1.0198e-02,  4.5933e-02,  5.2096e-02,  ...,  5.1632e-02,
           -3.2156e-04, -7.0693e-03],
          ...,
          [-1.0686e-02,  4.2775e-02,  5.1632e-02,  ...,  5.1836e-02,
           -9.9871e-04, -8.1672e-03],
          [ 3.5049e-02,  1.5740e-03, -3.2156e-04,  ..., -9.9871e-04,
            6.9687e-02,  5.4590e-02],
          [ 6.4166e-02, -1.4487e-03, -7.0693e-03,  ..., -8.1672e-03,
            5.4590e-02,  7.0837e-02]]]], grad_fn=<ExpandBackward0>)