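Toy notebook : prédire une gaussienne multivariée (moyenne + covariance pleine) à partir d'une variable latente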

Input : variable latente z de dimension Dz

Output : 

moyenne mu_x de dimension Dx
  
covariance sigma_x (Dx, Dx) définie positive et non diagonale
  
Doit fonctionner avec des dimensions de batch (B) et de longueur de séquence (L, notée N dans le code) :

(Dz) => (Dx) + (Dx,Dx)

(L,Dz) => (L,Dx) + (L,Dx,Dx)

(B,L,Dz) => (B,L,Dx) + (B,L,Dx,Dx)

In [None]:
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.distributions.multivariate_normal import MultivariateNormal as MVN

In [2]:
# Dimensions used throughout the notebook.
Dz = 8    # latent dimension (size of z)
Dx = 16   # observation dimension (size of mu_x; the covariance is Dx x Dx)
B = 32    # batch size
N = 500   # sequence length (written L in the shape table of the introduction)

# Seed torch's global RNG so that the random inputs, the layer initialisations and
# the printed sample matrices below are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED);  # trailing ';' hides the returned Generator repr in Jupyter

In [3]:
# Diagonal of the covariance matrix: nothing special to do. nn.Linear acts on the
# last dimension and keeps any leading (batch / sequence) dimensions, as the
# printed shapes below show for inputs of rank 1, 2 and 3.
diagonale = nn.Linear(Dz, Dx)

for z_shape in [(Dz,), (N, Dz), (B, N, Dz)]:
    z = torch.randn(*z_shape)
    x = diagonale(z)
    print(f"z shape: {z.shape} => x shape: {x.shape}")


z shape: torch.Size([8]) => x shape: torch.Size([16])
z shape: torch.Size([500, 8]) => x shape: torch.Size([500, 16])
z shape: torch.Size([32, 500, 8]) => x shape: torch.Size([32, 500, 16])


In [4]:
# First attempt at a lower-triangular factor for the covariance matrix.

class TriangularLower(nn.Module):
    """Map a latent z to the lower triangle of a dense (Dx, Dx) matrix.

    Shapes: (..., Dz) -> (..., Dx, Dx); leading dimensions of the input are kept.
    The diagonal is left unconstrained (it can be zero or negative), so the
    output is not guaranteed to be a valid Cholesky factor.
    """

    def __init__(self, Dz=Dz, Dx=Dx):
        super().__init__()
        layers = [nn.Linear(Dz, Dx * Dx), nn.Unflatten(-1, (Dx, Dx))]
        self.full = nn.Sequential(*layers)

    def forward(self, x):
        dense = self.full(x)
        return dense.tril()

triangulaire = TriangularLower()

In [5]:
# Rank check for the raw lower-triangular head: (..., Dz) -> (..., Dx, Dx).
for z_shape in ((Dz,), (N, Dz), (B, N, Dz)):
    z = torch.randn(*z_shape)
    L = triangulaire(z)
    print(f"z shape: {z.shape} => L shape: {L.shape}")

z shape: torch.Size([8]) => L shape: torch.Size([16, 16])
z shape: torch.Size([500, 8]) => L shape: torch.Size([500, 16, 16])
z shape: torch.Size([32, 500, 8]) => L shape: torch.Size([32, 500, 16, 16])


In [6]:
# Show one (Dx, Dx) matrix out of the batch (batch item 4, step 3), then its
# transpose, which is upper triangular. Relies on L from the previous cell
# being the rank-4 (B, N, Dx, Dx) tensor.
factor = L[4, 3]
print(factor)
print(factor.T)

tensor([[-0.1317,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.7879,  0.0343,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.0703, -0.1561,  0.8436,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-1.2832,  0.3787,  0.9443, -0.2443,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.1162, -0.1366,  0.5684,  0.8008, -0.6393,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.9695, -0.1239, -0.8288, -0.6332,  0.3228,  0.6021,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.4204, -0.0

In [7]:
class LowerTriangularForCovarianceMatrix(nn.Module):
    """Map a latent z to a lower-triangular factor L, intended for Sigma = L @ L^T.

    L is the sum of
      * a diagonal matrix with entries ``exp(linear(z)) + diag_eps``, and
      * the strictly lower triangle of a dense (Dx, Dx) linear head.
    Shapes: (..., Dz) -> (..., Dx, Dx); leading dimensions of z are kept, so the
    result can be passed as ``scale_tril`` to
    ``torch.distributions.MultivariateNormal``.

    Args:
        Dz: latent dimension (defaults to the notebook-level ``Dz``).
        Dx: output dimension (defaults to the notebook-level ``Dx``).
        diag_eps: non-negative constant added to the diagonal after ``exp``.
            In floating point, ``exp`` of a very negative pre-activation
            underflows to exactly 0. That gives a singular factor, which MVN's
            ``scale_tril`` constraint (positive diagonal) rejects. A small
            positive value (e.g. 1e-5) keeps the diagonal away from 0. This does
            not guard against ``exp`` overflowing to inf. The default 0.0
            reproduces the unparameterised behaviour exactly.

    Raises:
        ValueError: if ``diag_eps`` is negative.
    """

    def __init__(self, Dz=Dz, Dx=Dx, diag_eps=0.0):
        super().__init__()
        if diag_eps < 0:
            raise ValueError(f"diag_eps must be non-negative, got {diag_eps}")
        self.diag_eps = diag_eps
        # Attribute names are kept so that state_dict keys stay unchanged.
        self.diagonale = nn.Linear(Dz, Dx)
        self.full = nn.Sequential(
            nn.Linear(Dz, Dx * Dx),
            nn.Unflatten(-1, (Dx, Dx)),
            )

    def forward(self, z):
        # exp() makes the diagonal non-negative (positive unless it underflows);
        # diag_eps optionally bounds it away from 0.
        diag = torch.exp(self.diagonale(z)) + self.diag_eps
        D = torch.diag_embed(diag)
        # diagonal=-1 keeps only the strictly lower triangle, so the dense head's
        # own diagonal does not overwrite the positive diagonal above.
        T = torch.tril(self.full(z), diagonal=-1)
        L = D + T

        return L
    
lower_triangular = LowerTriangularForCovarianceMatrix()

In [8]:
# Rank check for the Cholesky-factor head: (..., Dz) -> (..., Dx, Dx).
z_shapes = [(Dz,), (N, Dz), (B, N, Dz)]
for z_shape in z_shapes:
    z = torch.randn(*z_shape)
    L = lower_triangular(z)
    print(f"z shape: {z.shape} => L shape: {L.shape}")

z shape: torch.Size([8]) => L shape: torch.Size([16, 16])
z shape: torch.Size([500, 8]) => L shape: torch.Size([500, 16, 16])
z shape: torch.Size([32, 500, 8]) => L shape: torch.Size([32, 500, 16, 16])


In [9]:
# Take one (Dx, Dx) factor from the (B, N, Dx, Dx) batch of the previous cell
# (batch item 4, step 3); its diagonal should now be positive thanks to exp().
sample_L = L[4][3]
print(sample_L)

tensor([[ 4.1057,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.1622,  1.7170,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.9138, -0.0357,  0.8783,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.3027,  0.0111,  0.4641,  1.6057,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.5145, -1.0874, -0.9443, -0.4787,  1.1992,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.2096, -0.7371, -0.4546, -0.3539,  0.3270,  0.2552,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.4632, -0.6

In [10]:
# Build multivariate normals from the predicted mean and lower-triangular factor,
# for inputs of rank 1, 2 and 3.
# The factor is passed through `scale_tril`. Passing `covariance_matrix` instead
# makes MVN recompute a Cholesky factorisation, which can fail numerically.
# NOTE(review): the `diagonale` Linear(Dz, Dx) layer is reused here as the mean network.

for z_shape in [(Dz,), (N, Dz), (B, N, Dz)]:
    z = torch.randn(*z_shape)
    mu = diagonale(z)
    L = lower_triangular(z)
    print(f"z shape: {z.shape} => mu shape: {mu.shape}, covar shape: {L.shape}")
    mvn = MVN(loc=mu, scale_tril=L)
    print(f"mvn loc: {mvn.loc.shape}, covariance_matrix: {mvn.covariance_matrix.shape}")
    print(f"mvn batch_shape: {mvn.batch_shape}, event_shape: {mvn.event_shape}")
    print()

z shape: torch.Size([8]) => mu shape: torch.Size([16]), covar shape: torch.Size([16, 16])
mvn loc: torch.Size([16]), covariance_matrix: torch.Size([16, 16])
mvn batch_shape: torch.Size([]), event_shape: torch.Size([16])

z shape: torch.Size([500, 8]) => mu shape: torch.Size([500, 16]), covar shape: torch.Size([500, 16, 16])
mvn loc: torch.Size([500, 16]), covariance_matrix: torch.Size([500, 16, 16])
mvn batch_shape: torch.Size([500]), event_shape: torch.Size([16])

z shape: torch.Size([32, 500, 8]) => mu shape: torch.Size([32, 500, 16]), covar shape: torch.Size([32, 500, 16, 16])
mvn loc: torch.Size([32, 500, 16]), covariance_matrix: torch.Size([32, 500, 16, 16])
mvn batch_shape: torch.Size([32, 500]), event_shape: torch.Size([16])



In [12]:
# Test for error-free instanciations of MVNs by brute force sampling...

N_TESTS = int(1e+4)

for _ in tqdm(range(N_TESTS)):
    
    z = torch.randn(B,N,Dz)
    mu = diagonale(z)
    L = lower_triangular(z)
    mvn = MVN(loc=mu, scale_tril=L)

100%|██████████| 10000/10000 [01:17<00:00, 128.44it/s]
