# Learning with partial data

In this tutorial, we show how to define a partially observed dataset that is compatible with MultiVae models.

## Incomplete dataset

The MultiVae library provides a simple class to handle incomplete datasets: the `IncompleteDataset` class, which inherits from the `torch.utils.data.Dataset` class. Its `__getitem__` method returns a `pythae.data.DatasetOutput` with a `data` field and a `masks` field.
Both are dictionaries containing a tensor for each modality. The masks are boolean tensors with `True` values where the modality data is available.

Below we demonstrate how to initialize a simple dataset from tensors using this class. 

```python
from multivae.data.datasets import IncompleteDataset, DatasetOutput
import torch

# Define random data samples
data = dict(
    modality_1=torch.randn((100, 3, 16, 16)),
    modality_2=torch.randn((100, 1, 10, 10)),
)
# Define random masks: boolean tensors where True indicates the modality is available
masks = dict(
    modality_1=torch.bernoulli(0.7 * torch.ones((100,))).bool(),
    modality_2=torch.ones((100,)).bool(),
)

# Arbitrary labels (optional)
labels = torch.bernoulli(0.5 * torch.ones((100,)))

dataset = IncompleteDataset(data, masks, labels)
dataset_without_labels = IncompleteDataset(data, masks)
```
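Before training, it can help to inspect the masks. The sketch below uses plain PyTorch and re-creates masks like the ones above; it reports how often each modality is observed and checks that every sample has at least one observed modality. That last check is our own sanity check, not a documented MultiVae requirement.

```python
import torch

n = 100
masks = dict(
    modality_1=torch.bernoulli(0.7 * torch.ones((n,))).bool(),
    modality_2=torch.ones((n,)).bool(),
)

# Fraction of samples in which each modality is observed
for name, mask in masks.items():
    print(f"{name}: {mask.float().mean().item():.2f} observed")

# Stack masks into a (n_modalities, n_samples) tensor, then check that
# every sample has at least one observed modality
all_masks = torch.stack(list(masks.values()))
has_any = all_masks.any(dim=0)
print(bool(has_any.all()))  # True here, since modality_2 is always observed
```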

You can also define a fully custom dataset with the same structure as the `IncompleteDataset` class:
- the `__getitem__` method must return a `DatasetOutput` instance with a `data` field containing a dictionary of tensors, a `masks` field containing a dictionary of boolean masks, and an optional `labels` field containing a tensor.

Below, we provide a very simple example. 

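```python
import torch
from multivae.data.datasets import IncompleteDataset, DatasetOutput


class MyCustomPartialDataset(IncompleteDataset):
    """Toy dataset returning random samples for two modalities."""

    def __init__(self, shape_1, shape_2) -> None:
        self.shape_1 = shape_1
        self.shape_2 = shape_2

    def __getitem__(self, index):
        # Return a random data point with its masks
        data = dict(
            mod1=torch.randn(self.shape_1),
            mod2=torch.randn(self.shape_2),
        )
        # One 0-d boolean flag per modality, matching the per-sample shape
        # obtained by indexing the (100,) masks of the first example.
        # mod1 is missing when index % 4 == 1 and mod2 when index % 4 == 3,
        # so every sample keeps at least one observed modality.
        masks = dict(
            mod1=torch.tensor(index % 4 != 1),
            mod2=torch.tensor(index % 4 != 3),
        )
        return DatasetOutput(data=data, masks=masks)

    def __len__(self):
        return 100


dataset = MyCustomPartialDataset((2, 23, 4), (1, 4))
```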
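The shape of each per-sample mask matters once samples are batched. As a sketch in plain PyTorch, with ordinary dictionaries standing in for `DatasetOutput` and assuming batching behaves like PyTorch's `default_collate` (not checked against MultiVae's trainer), 0-d masks stack into a `(batch,)` tensor, while 1-element masks stack into `(batch, 1)`, which broadcasts unexpectedly against a per-sample loss of shape `(batch,)`.

```python
import torch
from torch.utils.data import default_collate

# Per-sample masks stored as 0-d boolean tensors
samples_0d = [{"mod1": torch.tensor(i % 4 != 1)} for i in range(8)]
# Per-sample masks stored as 1-element boolean tensors
samples_1d = [{"mod1": torch.tensor([i % 4 != 1])} for i in range(8)]

batch_0d = default_collate(samples_0d)
batch_1d = default_collate(samples_1d)
print(batch_0d["mod1"].shape)  # torch.Size([8])
print(batch_1d["mod1"].shape)  # torch.Size([8, 1])

# Multiplying a per-sample loss of shape (8,) by each mask
per_sample_loss = torch.arange(8, dtype=torch.float32)
print((per_sample_loss * batch_0d["mod1"]).shape)  # torch.Size([8])
print((per_sample_loss * batch_1d["mod1"]).shape)  # torch.Size([8, 8]) via broadcasting
```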


For a more realistic example, check out the MMNISTDataset class in `multivae.data.datasets.mmnist`. That dataset has five image modalities and can be initialized with partially missing data.


The following models in MultiVae can be trained on partially observed data, using exactly the same training process as for a complete dataset:
- MMVAE
- MVAE
- MoPoE
- MVTCAE
- MMVAE+
- DMVAE
- CMVAE

In each batch, the loss components corresponding to missing modalities are filtered out using the provided masks, in a way that respects the ELBO formulation.
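To give an intuition for this filtering, here is a simplified sketch that masks only the per-modality reconstruction terms. The helper `masked_reconstruction_loss` is hypothetical and is not MultiVae's implementation; it leaves out the KL and posterior terms, whose handling depends on the model.

```python
import torch


def masked_reconstruction_loss(recon_losses, masks):
    """Hypothetical helper: per-sample sum of reconstruction losses,
    keeping only the terms of modalities observed in each sample.

    recon_losses: dict mod_name -> float tensor of shape (batch,)
    masks: dict mod_name -> bool tensor of shape (batch,)
    """
    total = torch.zeros_like(next(iter(recon_losses.values())))
    for mod, loss in recon_losses.items():
        # Zero out the terms of missing modalities (keeps the batch dimension)
        total = total + loss * masks[mod].float()
    return total


recon_losses = {
    "m0": torch.tensor([1.0, 2.0, 3.0]),
    "m1": torch.tensor([10.0, 20.0, 30.0]),
}
masks = {
    "m0": torch.tensor([True, False, True]),
    "m1": torch.tensor([True, True, False]),
}
per_sample = masked_reconstruction_loss(recon_losses, masks)
print(per_sample.tolist())  # [11.0, 20.0, 3.0]
```

Multiplying by the mask rather than indexing out missing samples keeps every term aligned with its sample, so the result can then be averaged over the batch as usual.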

![image](../../static/handling_incomplete.png)


Below we instantiate a partial PolyMNIST dataset:

```python
from multivae.data.datasets.mmnist import MMNISTDataset

dataset = MMNISTDataset(
    data_path='~/scratch/data/',
    split='train',
    # download=True,
    missing_ratio=0.2,  # create missing-at-random blanks in the dataset
)
```
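The exact mechanism behind `missing_ratio` is defined inside `MMNISTDataset`. Purely as an illustration, and not as a description of that class, the hypothetical helper below drops each of five modalities independently with probability `missing_ratio`, then re-enables one randomly chosen modality in any sample where all were dropped.

```python
import torch


def random_missing_masks(n_samples, modalities, missing_ratio, seed=0):
    """Hypothetical helper, not part of MultiVae."""
    gen = torch.Generator().manual_seed(seed)
    # Each modality is independently observed with probability 1 - missing_ratio
    observed = torch.rand(len(modalities), n_samples, generator=gen) >= missing_ratio
    # Samples in which every modality was dropped
    empty = ~observed.any(dim=0)
    # For those samples, re-enable one modality chosen uniformly at random
    mod_idx = torch.randint(len(modalities), (int(empty.sum()),), generator=gen)
    observed[mod_idx, empty.nonzero(as_tuple=True)[0]] = True
    return {mod: observed[i] for i, mod in enumerate(modalities)}


masks = random_missing_masks(1000, [f"m{i}" for i in range(5)], missing_ratio=0.2)
stacked = torch.stack(list(masks.values()))
print(stacked.float().mean().item())  # fraction of observed entries, close to 0.8
print(bool(stacked.any(dim=0).all()))  # True
```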

Below is a simple example of a model trained on this incomplete dataset.

```python
# Define a model
from multivae.models import MVTCAE, MVTCAEConfig

model_config = MVTCAEConfig(
    n_modalities=5,
    input_dims={f'm{i}': (3, 28, 28) for i in range(5)},
    latent_dim=32,
)

model = MVTCAE(model_config)

# Define a trainer
from multivae.trainers import BaseTrainer, BaseTrainerConfig

training_config = BaseTrainerConfig(
    learning_rate=1e-3,
    num_epochs=1,
)

trainer = BaseTrainer(model, dataset, training_config=training_config)

trainer.train()
```


```
! No eval dataset provided ! -> keeping best model on train.

Model passed sanity check !
Ready for training.

Setting the optimizer with learning rate 0.001
Created dummy_output_dir/MVTCAE_training_2023-06-14_10-17-24. 
Training config, checkpoints and final model will be saved here.

Training params:
 - max_epochs: 1
 - per_device_train_batch_size: 64
 - per_device_eval_batch_size: 64
 - checkpoint saving every: None
Optimizer: Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.001
    maximize: False
    weight_decay: 0
)
Scheduler: None

Successfully launched training !

Training of epoch 1/1: 100%|██████████| 938/938 [01:24<00:00, 11.06batch/s]
--------------------------------------------------------------------------
Train loss: 9151.6141
--------------------------------------------------------------------------
Training ended!
Saved final model in dummy_ou
```

Scripts with a complete example of training and validating models on the partially observed PolyMNIST dataset are available at
https://github.com/AgatheSenellart/nips_experiments.