In [1]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import pandas as pd

In [2]:
from typing import Any, List

General case:
- Choose a prior for $Z$: $p(Z)$.
- Choose an observation model: $p_\theta(X|Z)$
- Choose a variational posterior: $q_{\gamma}(\mathbf{z} | \mathbf{x})$

- Choose a missing model: $p_{\phi}(\mathbf{S} | \mathbf{X^o, X^m})$


The ELBO in the MNAR case is

$$ E_{(\mathbf{z}_1, \mathbf{x}_1^m)...(\mathbf{z}_K, \mathbf{x}_K^m)} \left[ \log \frac{1}{K} \sum_{k=1}^K \frac{p_{\phi}(\mathbf{s} | \mathbf{x}^o, \mathbf{x}_k^m) p_{\theta}(\mathbf{x}^o | \mathbf{z}_k) p(\mathbf{z}_k)}{q_{\gamma}(\mathbf{z} | \mathbf{x}^o)} \right]$$

### Classic case
The model we are building has a Gaussian prior and a Gaussian observation model (also the decoder ($z \rightarrow x$) ),

$$ p(\mathbf{z}) = \mathcal{N}(\mathbf{z} | \mathbf{0}, \mathbf{I})$$

$$ p_\theta(\mathbf{x} | \mathbf{z}) = \mathcal{N}(\mathbf{x} | \mathbf{\mu}_{\theta}(\mathbf{z}), \sigma^2\mathbf{I})$$

$$ p_\theta(\mathbf{x}) = \int p_\theta(\mathbf{x} | \mathbf{z})p(\mathbf{z}) d\mathbf{z}$$

where $\mathbf{\mu}_{\theta}(\mathbf{z}): \mathbb{R}^d \rightarrow \mathbb{R}^p $ in general is a deep neural net, but in this case is a linear mapping, $\mathbf{\mu} = \mathbf{Wz + b}$.

The variational posterior (also the encoder ($x \rightarrow z$) ) is also Gaussian

$$q_{\gamma}(\mathbf{z} | \mathbf{x}) = \mathcal{N}(\mathbf{z} | \mu_{\gamma}(\mathbf{x}), \sigma_{\gamma}(\mathbf{x})^2 \mathbf{I})$$

If the missing process is *missing at random*, it is ignorable and the ELBO becomes, as described in [the MIWAE paper](https://arxiv.org/abs/1812.02633)

$$ E_{\mathbf{z}_1...\mathbf{z}_K} \left[ \log \frac{1}{K}\sum_{k=1}^K \frac{p_{\theta}(\mathbf{x^o} | \mathbf{z}_k)p(\mathbf{z}_k)}{q_{\gamma}(\mathbf{z}_k | \mathbf{x^o})} \right] $$

When the missing process is MNAR it is non-ignorable and we need to include the missing model. In this example we include the missing model as a logistic regression in each feature dimension

$$ p_{\phi}(\mathbf{s} | \mathbf{x^o, x^m}) = \text{Bern}(\mathbf{s} | \pi_{\phi}(\mathbf{x^o, x^m}))$$

$$ \pi_{\phi, j}(x_j) = \frac{1}{1 + e^{-\text{logits}_j}} $$

$$ \text{logits}_j = W_j (x_j - b_j) $$

The ELBO in the MNAR case becomes

$$ E_{(\mathbf{z}_1, \mathbf{x}_1^m)...(\mathbf{z}_K, \mathbf{x}_K^m)} \left[ \log \frac{1}{K} \sum_{k=1}^K \frac{p_{\phi}(\mathbf{s} | \mathbf{x}^o, \mathbf{x}_k^m) p_{\theta}(\mathbf{x}^o | \mathbf{z}_k) p(\mathbf{z}_k)}{q_{\gamma}(\mathbf{z} | \mathbf{x}^o)} \right]$$

with $ z \sim q_{\gamma}(z|x^o), x^m\sim p_\theta(x^m|z)$

### Constant to define

 - $K$ = $n_{\text{samples}}$ the number of sample to estimate the expectation
 - $n_{\text{latent}}$ the dimension of the latent space where $z$ lives


### Load data
Here we use the white-wine dataset from the UCI database

In [3]:
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-white.csv"
data = np.array(pd.read_csv(url, low_memory=False, sep=';'))
# ---- drop the classification attribute
data = data[:, :-1]

### Settings

In [4]:
N, D = data.shape
n_latent = D - 1
n_hidden = 128
n_samples = 20
max_iter = 30000
batch_size = 16

### Standardize data

In [5]:
# ---- standardize data
data = data - np.mean(data, axis=0)
data = data / np.std(data, axis=0)

# ---- random permutation
p = np.random.permutation(N)
data = data[p, :]

# ---- we use the full dataset for training here, but you can make a train-val split
Xtrain = data.copy()
Xval = Xtrain.copy()

### Introduce missing 
Here we denote
- Xnan: data matrix with np.nan as the missing entries
- Xz: data matrix with 0 as the missing entries
- S: missing mask 

The missing process depends on the missing data itself:
- in half the features, set the feature value to missing when it is higher than the feature mean

In [6]:
# ---- introduce missing process
Xnan = Xtrain.copy()
Xz = Xtrain.copy()

mean = np.mean(Xnan[:, :int(D / 2)], axis=0)
ix_larger_than_mean = Xnan[:, :int(D / 2)] > mean

Xnan[:, :int(D / 2)][ix_larger_than_mean] = np.nan
Xz[:, :int(D / 2)][ix_larger_than_mean] = 0

S = np.array(~np.isnan(Xnan), dtype=np.float32)

In [7]:
def check_nan(tens,name=None):
    if torch.isnan(tens).any().item():
        print(name)
        print(tens)

In [8]:
class Clip(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
    
    def forward(self, x):
        return torch.clip(x,-10,10)

In [9]:
class Distribution():
    def __init__(self) -> None:
        
        pass
    def rsample(self, sample_shape):
        print("Not implemented")
        pass
    def log_prob(self, value):
        print("Not implemented")
        pass

In [30]:
a = torch.FloatTensor([1,2,3])
a
torch.FloatTensor([a])

ValueError: only one element tensors can be converted to Python scalars

In [24]:
class GaussDistribution(Distribution):
    def __init__(self, loc: Any = 0., scale: Any = 1.) -> None:
        super().__init__()

        self.mu = torch.tensor(loc, dtype=torch.float32)
        self.sigma = torch.tensor(scale, dtype=torch.float32)

    def rsample(self, sample_shape):
        samples = self.mu + self.sigma * torch.randn(sample_shape, dtype=self.mu.dtype)
        return samples

    def log_prob(self, value):
        eps = torch.finfo(torch.float32).eps

        log_p = - 0.5 * torch.log(2 * np.pi) - 0.5 * torch.log(self.sigma**2) \
                      - 0.5 * torch.square(value - self.mu) / (self.sigma**2 + eps)
        return log_p 



In [26]:
class notMIWAE(nn.Module):
    #Only Gaussian and Bern for the moment
    def __init__(self, input_size = 10, n_latent = 20, n_samples = 10):
        super(notMIWAE, self).__init__()

        self.n_input = input_size
        self.n_latent = n_latent
        self.n_samples = n_samples

        self.encoder_mu = nn.Linear(in_features=input_size, out_features=n_latent)
        self.encoder_logsigma = nn.Sequential(nn.Linear(in_features=input_size, out_features=n_latent),Clip())
        
        self.decoder_mu = nn.Linear(in_features=n_latent, out_features=input_size)
        # self.decoder_logsigma = nn.Linear(in_features=n_latent, out_features=input_size)

        # Missing mechanism
        self.logits = nn.Linear(in_features=input_size, out_features=input_size)

        self.sigma = torch.ones(n_latent)

        self.prior = GaussDistribution(loc = 0., scale = 1.) # torch.distributions.normal.Normal(loc = 0., scale = 1.)
        

    def elbo(self, x, s):
        """
        x : the input of size (batch, input_size)
        s : the mask of size (batch, input_size) s[i,j] = 1 if x[i,j] exists else 0
        """
        
        z_mu = self.encoder_mu(x) # (batch, n_latent)
        z_sigma = torch.sqrt(torch.exp(self.encoder_logsigma(x))) # (batch, n_latent)
       
        
        law_z_given_x= torch.distributions.normal.Normal(loc = z_mu, scale = z_sigma) # Distribution with parameter of size (batch, n_latent)

        z_samples = law_z_given_x.rsample((self.n_samples,1)).squeeze() # (n_samples, batch, n_latent)

        log_prob_z_given_x = law_z_given_x.log_prob(z_samples).sum(dim=-1) # (n_samples, batch)
        
        
        z_samples = z_samples.transpose(0,1) # (batch, n_samples, n_latent)
        log_prob_z_given_x = log_prob_z_given_x.transpose(0,1) # (batch, n_samples)

        law_z = self.prior

        log_prob_z = law_z.log_prob(z_samples).sum(dim=-1) # (batch, n_samples)

        x_mu = self.decoder_mu(z_samples) # (batch, n_samples, input_size)

        x_sigma = 1 # torch.sqrt(torch.exp(self.decoder_logsigma(z_samples) + 1e-5)) # (batch, n_samples, input_size)

        law_x_given_z = torch.distributions.normal.Normal(loc = x_mu, scale = x_sigma) # Distribution with parameter of size (batch, n_samples, input_size)

        x_samples  = law_x_given_z.rsample().squeeze() # (batch, n_samples, input_size)

        log_prob_x_given_z = (law_x_given_z.log_prob(x.unsqueeze(1)) * s.unsqueeze(1)).sum(dim=-1) # (batch, n_samples)


        
        # Missing mechanism
        
        # We recreate the x_sample using the real x we know (x_o) and the x_samples we created from z (x_m).
        mixed_x_samples = x_samples * (1-s).unsqueeze(1) + (x*s).unsqueeze(1) # (batch, n_samples, input_size)

        logits = self.logits(mixed_x_samples) # (batch, n_samples, input_size)

        law_s_given_x = torch.distributions.bernoulli.Bernoulli(logits=logits) # Distribution with parameter of size (batch, n_samples, input_size)

        log_prob_s_given_x = law_s_given_x.log_prob(s.unsqueeze(1)).sum(dim=-1) # (batch, n_samples)


        log_sum_w = torch.logsumexp(log_prob_s_given_x + log_prob_x_given_z + log_prob_z - log_prob_z_given_x, dim = 1) # (batch)
        log_mean_w = log_sum_w - torch.log(torch.Tensor([self.n_samples])) # (batch)

        
        return  - log_mean_w.mean()
    # law_z_given_x2 = torch.distributions.normal.Normal(loc=z_mu.unsqueeze(0), scale=z_sigma.unsqueeze(0))



In [27]:
N, p = Xtrain.shape
X = torch.FloatTensor(Xz)
S = torch.FloatTensor(S)
batch_size = 100
epochs = 20
model = notMIWAE(input_size=p,n_samples=20)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
for epoch in range(epochs):
    print(f'Epochs:{epoch+1}')
    p = np.random.permutation(N)
    X = X[p,:]
    S = S[p,:]
    
    for i in range(0,N,batch_size):
        print(f'{i+batch_size} / {N}', end="\r")
        X_batch = X[i:(i+batch_size)]
        S_batch = S[i:(i+batch_size)]

        if torch.isnan(X_batch).any().item():
                print('NaN X_batch')
        if torch.isnan(S_batch).any().item():
                print('NaN S_batch')
                
        elbo = model.elbo(X_batch,S_batch)
        
        optimizer.zero_grad()
        elbo.backward()
        optimizer.step()
        for param in model.parameters():
            if torch.isnan(param).any().item():
                print('NaN parameter')
                print(torch.isnan(param).any().item())
    
    print('loss', model.elbo(X,S).item())
        
        


TypeError: new(): data must be a sequence (got float)

In [None]:
for param in model.parameters():
    print(torch.isnan(param).any().item())

False
False
False
False
False
False
False
False


In [None]:
law_z_given_x= torch.distributions.normal.Normal(loc = torch.zeros((3,5)), scale = torch.ones((3,5)))

z_samples = law_z_given_x.sample((1,1)).squeeze()

#check size probably need to transpose
print(z_samples)
log_prob_z_given_x = law_z_given_x.log_prob(z_samples)


tensor([[-0.0074, -0.2796,  0.6426,  1.1826, -1.0374],
        [-0.5127, -0.9160, -0.4352, -1.6984, -1.0139],
        [-1.3056, -1.3189, -0.7794, -1.4175,  0.3863]])


In [None]:
print( log_prob_z_given_x)

tensor([[-0.9190, -0.9580, -1.1254, -1.6182, -1.4571],
        [-1.0504, -1.3385, -1.0136, -2.3611, -1.4330],
        [-1.7712, -1.7886, -1.2227, -1.9236, -0.9936]])


In [None]:
a = np.array([ 0.7964,  0.9837, -0.1394,  0.5177, -0.8972])
-0.7964**2 / 2 - 0.5 * np.log(2 * np.pi) 

-1.2360650132046727