In [None]:
# Move the working directory one level up (presumably the repository root) so that relative
# paths such as `paper_figures/` used further below resolve there — TODO confirm repo layout.
# NOTE(review): not idempotent — running this cell a second time moves up yet another level.
%cd ..

# Tutorial: Block-wise missing data generation

## Prerequisites

We will need the following libraries installed: `imvc`, `numpy`, `pandas`, `matplotlib` and `tueplots` (used for the ICML 2022 figure style).

## Step 1: Import required libraries

In [None]:
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from imvc.ampute import Amputer
from imvc.utils import DatasetUtils
from imvc.impute import get_observed_view_indicator

In [None]:
from tueplots import axes, bundles

# Apply the ICML 2022 figure bundle, then layer the axes line settings on top of it.
# Updating twice in this order yields the same rcParams as updating with the merged dict.
plt.rcParams.update(bundles.icml2022())
plt.rcParams.update(axes.lines())

In [None]:
# Enlarge every font-size setting by a fixed number of points for better legibility.
# The base sizes are cached the first time this cell runs, so re-running it is idempotent
# (a plain `+= 3` on the global rcParams grew the fonts again on every execution).
FONT_SIZE_INCREMENT = 3
FONT_SIZE_KEYS = ("axes.labelsize", "axes.titlesize", "font.size",
                  "legend.fontsize", "xtick.labelsize", "ytick.labelsize")
if "base_font_sizes" not in globals():
    # assumes these rcParams are numeric here (presumably set by the tueplots bundle) — a named
    # size such as "medium" would fail the addition below; TODO confirm if the style changes.
    base_font_sizes = {key: plt.rcParams[key] for key in FONT_SIZE_KEYS}
plt.rcParams.update({key: size + FONT_SIZE_INCREMENT for key, size in base_font_sizes.items()})

## Step 2: Load the dataset

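Let's create a random multi-view dataset with 10 samples and 4 views, each view having 10 features. (A larger dataset with 1000 samples and 5 views is used at the end of this notebook.)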

In [None]:
RANDOM_STATE = 42
n_views = 4
n_samples = 10
n_features = 10
# Draw every view from ONE seeded generator. Creating a fresh `default_rng(RANDOM_STATE)`
# per view made all views bit-for-bit identical copies; the first view is unchanged.
rng = np.random.default_rng(RANDOM_STATE)
Xs = [pd.DataFrame(rng.random((n_samples, n_features))) for _ in range(n_views)]

## Step 3: Apply missing data mechanism (Amputation)

Using `Amputer`, we randomly introduce block-wise missing data to simulate a scenario where some modalities are missing for some samples. Here, `p=0.8`, so 80% of the samples will be incomplete.

In [None]:
mechanism = "edm"
amputer = Amputer(mechanism=mechanism, p=0.8, random_state=RANDOM_STATE)
# Apply the chosen missing-data mechanism to all views at once.
transformed_Xs = amputer.fit_transform(Xs)

We can visualize which modalities are missing using a binary color map, where black means observed and white means empty (missing).

In [None]:
xlabel, ylabel = "Modality", "Samples"
# Observation indicator for the amputed views (presumably one row per sample, one column per
# view), sorted by every view column so samples sharing a missing pattern are grouped together.
observed_view_indicator = get_observed_view_indicator(transformed_Xs).sort_values(list(range(len(transformed_Xs))))
# Number modalities from 1 for display.
observed_view_indicator.columns = observed_view_indicator.columns + 1
fig, ax = plt.subplots()
ax.pcolor(observed_view_indicator, cmap="binary", edgecolors="gray")
# One tick centred on each modality column.
ax.set_xticks(np.arange(0.5, len(observed_view_indicator.columns), 1), observed_view_indicator.columns)
ax.set(xlabel=xlabel, ylabel=ylabel);

In [None]:
# Export one standalone figure per amputation mechanism for the paper.
names_dict = {"edm": "Equally distributed missing",
              "pm": "Partial missing",
              "mcar": "Missing completely at random",
              "mnar": "Missing not at random",
              }
xlabel, ylabel = "Modality", "Samples"
figures_dir = Path("paper_figures")
# savefig does not create missing directories.
figures_dir.mkdir(parents=True, exist_ok=True)
for mechanism in ["edm", "pm", "mcar", "mnar"]:
    transformed_Xs = Amputer(mechanism=mechanism, p=0.8, random_state=RANDOM_STATE).fit_transform(Xs)
    observed_view_indicator = get_observed_view_indicator(transformed_Xs).sort_values(list(range(len(transformed_Xs))))
    observed_view_indicator.columns = observed_view_indicator.columns + 1
    # A fresh figure per mechanism: previously every iteration drew onto the same implicit
    # axes, so each exported file also contained the meshes of the earlier mechanisms.
    fig, ax = plt.subplots()
    # NOTE(review): "binary_r" inverts the colour coding of the single-mechanism plot above,
    # which uses "binary" — confirm which convention the paper figures should use.
    ax.pcolor(observed_view_indicator, cmap="binary_r", edgecolors="gray")
    ax.set_xticks(np.arange(0.5, len(observed_view_indicator.columns), 1), observed_view_indicator.columns)
    ax.set(xlabel=xlabel, ylabel=ylabel, title=names_dict[mechanism])
    fig.savefig(figures_dir / f"amputation_{mechanism}.pdf")
    fig.savefig(figures_dir / f"amputation_{mechanism}.svg")
    # Show the figure inline, then release it before the next iteration.
    plt.show()
    plt.close(fig)

In [None]:
# Disabled cell: when uncommented it exports a reference figure for p=0 (`amputation_0.*`,
# presumably a dataset with no amputed samples — TODO confirm Amputer's p=0 behaviour).
# NOTE(review): it uses cmap="binary" while the per-mechanism paper figures use "binary_r",
# so the colour coding would be inverted relative to them — confirm before regenerating.
# mechanism = "mnar"
# transformed_Xs = Amputer(mechanism=mechanism, p=0., random_state=RANDOM_STATE).fit_transform(Xs)
# xlabel,ylabel = "Modality", "Samples"
# observed_view_indicator = get_observed_view_indicator(transformed_Xs).sort_values(list(range(len(transformed_Xs))))
# observed_view_indicator.columns = observed_view_indicator.columns + 1
# plt.pcolor(observed_view_indicator, cmap="binary", edgecolors="gray")
# plt.xticks(np.arange(0.5, len(observed_view_indicator.columns), 1), observed_view_indicator.columns)
# _ = plt.xlabel(xlabel), plt.ylabel(ylabel)
# plt.savefig(f"paper_figures/amputation_0.pdf")
# plt.savefig(f"paper_figures/amputation_0.svg")

## Step 4: Visualize different amputation mechanisms 

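We will show the four different amputation mechanisms: equally distributed missing (EDM), partial missing (PM), missing completely at random (MCAR) and missing not at random (MNAR).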

In [None]:
# Side-by-side comparison of the four mechanisms; a per-mechanism summary is collected in
# `samples_dict` for the table below.
xlabel, ylabel = "Modality", "Samples"
figures_dir = Path("paper_figures")
# savefig does not create missing directories.
figures_dir.mkdir(parents=True, exist_ok=True)
samples_dict = {}
fig, axs = plt.subplots(1, 4, figsize=(12, 2.5))
for idx, (ax, mechanism) in enumerate(zip(axs, ["edm", "pm", "mcar", "mnar"])):
    transformed_Xs = Amputer(mechanism=mechanism, p=0.8, random_state=RANDOM_STATE).fit_transform(Xs)
    observed_view_indicator = get_observed_view_indicator(transformed_Xs).sort_values(list(range(len(transformed_Xs))))
    observed_view_indicator.columns = observed_view_indicator.columns + 1
    ax.pcolor(observed_view_indicator, cmap="binary_r", edgecolors="gray")
    ax.set_title(f"Mechanism = {mechanism}")
    ax.set_xticks(np.arange(0.5, len(observed_view_indicator.columns), 1), observed_view_indicator.columns)
    ax.set(xlabel=xlabel, ylabel=ylabel)
    if idx != 0:
        # Only the left-most panel keeps its sample axis; all panels have the same sample rows.
        ax.get_yaxis().set_visible(False)
    samples_dict[mechanism] = DatasetUtils.get_summary(Xs=transformed_Xs)
fig.tight_layout()
fig.savefig(figures_dir / "amputation.pdf")
fig.savefig(figures_dir / "amputation.svg")

As shown in the table below, all mechanisms produce the same number of complete and incomplete samples. However, the number of observed samples in each modality varies depending on the chosen missing pattern.

In [None]:
# One row per mechanism; columns come from the per-mechanism summaries
# (presumably dicts returned by DatasetUtils.get_summary).
samples_summary = pd.DataFrame.from_dict(data=samples_dict, orient="index")
samples_summary

In [None]:
# Emit the same summary table as raw LaTeX source; print keeps it copy-pasteable.
latex_table = pd.DataFrame.from_dict(samples_dict, orient="index").to_latex()
print(latex_table)

In [None]:
# Sweep the missing rate p on a larger synthetic dataset: for each p, show the four
# mechanisms side by side, directly followed by their summary table.
n_views = 5
n_samples = 1000
xlabel, ylabel = "Modality", "Samples"
# One seeded generator for all views, so the views are not identical copies of each other.
rng = np.random.default_rng(RANDOM_STATE)
Xs = [pd.DataFrame(rng.random((n_samples, 10))) for _ in range(n_views)]
# Rounding removes float artefacts of arange such as p = 0.30000000000000004.
for p in np.round(np.arange(0.1, 1., 0.1), 1):
    samples_dict = {}
    fig, axs = plt.subplots(1, 4, figsize=(12, 2.5))
    for idx, (ax, mechanism) in enumerate(zip(axs, ["edm", "pm", "mcar", "mnar"])):
        transformed_Xs = Amputer(mechanism=mechanism, p=p, random_state=RANDOM_STATE).fit_transform(Xs)
        observed_view_indicator = get_observed_view_indicator(transformed_Xs).sort_values(list(range(len(transformed_Xs))))
        observed_view_indicator.columns = observed_view_indicator.columns + 1
        # No cell edges here: with 1000 sample rows they would hide the pattern.
        ax.pcolor(observed_view_indicator, cmap="binary_r")
        ax.set_title(f"Mechanism = {mechanism}")
        ax.set_xticks(np.arange(0.5, len(observed_view_indicator.columns), 1), observed_view_indicator.columns)
        ax.set(xlabel=xlabel, ylabel=ylabel)
        if idx != 0:
            ax.get_yaxis().set_visible(False)
        samples_dict[mechanism] = DatasetUtils.get_summary(Xs=transformed_Xs)
    # Label the figure with its missing rate; otherwise the nine figures are indistinguishable.
    fig.suptitle(f"p = {p:.1f}")
    fig.tight_layout()
    # Render the figure now so it appears next to its table instead of after the last table.
    plt.show()
    plt.close(fig)
    display(pd.DataFrame.from_dict(samples_dict, orient="index"))