# Chapter 5 - Connecting Causality and Deep Learning

The notebook is a code companion to chapter 5 of the book [Causal AI](https://www.manning.com/books/causal-ai) by [Robert Osazuwa Ness](https://www.linkedin.com/in/osazuwa/).

<a href="https://colab.research.google.com/github/altdeep/causalML/blob/master/book/chapter%205/chapter_5_Connecting_Causality_and_Deep_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

pgmpy allows us to fit conventional Bayesian networks on a causal DAG. However, with modern deep probabilistic machine learning frameworks like pyro, we can build more nuanced and powerful causal models.  In this tutorial, we fit a variational autoencoder on a causal DAG that represents a dataset that mixes handwritten MNIST digits and typed T-MNIST images. 

![TMNIST-MNIST](images/MNIST-TMNIST.png)

In [1]:
#!pip install pyro-ppl==1.8.4

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from torch.utils.data import Dataset, DataLoader
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import random_split
from torchvision import transforms

ModuleNotFoundError: No module named 'torchvision'

In [None]:
# --- Hardware -------------------------------------------------------------
# Set USE_CUDA to True to place the networks and batches on the GPU.
USE_CUDA = False
DEVICE_TYPE = torch.device("cuda") if USE_CUDA else torch.device("cpu")

# --- Optimisation schedule ------------------------------------------------
BATCH_SIZE = 256
LEARNING_RATE = 1e-3
NUM_EPOCHS = 2500
# The test split is evaluated on every epoch divisible by this value.
TEST_FREQUENCY = 10

# Not referenced by any code visible in this notebook.
REINIT_PARAMS = True

# Switch off Pyro's distribution argument validation.
pyro.distributions.enable_validation(False)

First, we download the data and combine it into a Dataset object.

In [None]:
class CombinedDataset(Dataset):
    """Image dataset backed by a single CSV table.

    Rows are read positionally: column 1 holds the handwritten flag,
    column 2 the digit label, and columns 3 onward the pixel intensities
    (reshaped to 28x28, so 784 values per row are expected). Column 0 is
    not used.
    """

    def __init__(self, csv_file):
        """Load the table.

        Args:
            csv_file: Anything ``pandas.read_csv`` accepts (path, URL,
                or file-like object).
        """
        self.dataset = pd.read_csv(csv_file)

    def __len__(self):
        """Return the number of rows (examples) in the table."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, digit, is_handwritten)`` for row ``idx``.

        Returns:
            image: float32 ``torch.Tensor`` of shape ``(1, 28, 28)`` holding
                the pixel values divided by 255.
            digit: int64 NumPy array of shape ``(1,)``.
            is_handwritten: float32 NumPy array of shape ``(1,)``.
        """
        pixels = np.array(self.dataset.iloc[idx, 3:], dtype='float32') / 255.
        # Add the leading channel axis directly; building a torchvision
        # ToTensor transform per item does nothing more for a float32 array.
        image = torch.from_numpy(pixels.reshape(28, 28)).unsqueeze(0)

        # Pin the label to int64: plain 'int' is 32-bit on some platforms,
        # and F.one_hot only accepts int64 index tensors.
        digit = np.array([self.dataset.iloc[idx, 2]], dtype='int64')
        is_handwritten = np.array(
            [self.dataset.iloc[idx, 1]], dtype='float32'
        )
        return image, digit, is_handwritten

def setup_dataloaders(batch_size=64, use_cuda=USE_CUDA):
    """Build shuffled train and test ``DataLoader`` objects for the combined data.

    The CSV is read from the project's GitHub URL, then split 80% / 20%
    with a generator seeded at 42 so the partition is repeatable.

    Args:
        batch_size: Number of examples per batch for both loaders.
        use_cuda: Forwarded to the loaders as ``pin_memory``.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    full_data = CombinedDataset(
        "https://raw.githubusercontent.com/altdeep/causalML/master/datasets/combined_mnist_tmnist_data.csv"
    )
    total = len(full_data)
    n_train = int(0.8 * total)
    n_test = total - n_train

    split_rng = torch.Generator().manual_seed(42)
    train_part, test_part = random_split(
        full_data, [n_train, n_test], generator=split_rng
    )

    def make_loader(subset):
        # Both splits use the same loader settings, shuffling included.
        return DataLoader(
            subset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=1,
            pin_memory=use_cuda,
        )

    return make_loader(train_part), make_loader(test_part)

Next, we specify an encoder and a decoder. The decoder maps the latent variable Z, a variable representing the value of the digit, and a binary variable representing whether the digit is handwritten to an image. The encoder takes an image, the digit, and whether the digit is handwritten, and infers the latent representation Z.

In [None]:
class Decoder(nn.Module):
    """Network mapping ``(z, digit, is_handwritten)`` to per-pixel parameters.

    The three inputs are concatenated along dim 1, passed through one
    softplus hidden layer, and squashed by a sigmoid into ``28 * 28``
    values in (0, 1).
    """

    def __init__(self, z_dim, hidden_dim):
        """Create the two linear layers.

        Args:
            z_dim: Width of the latent code ``z``.
            hidden_dim: Width of the hidden layer.
        """
        super().__init__()
        # Input width = latent code + 10 digit slots + 1 handwritten flag.
        in_features = z_dim + 10 + 1
        out_features = 28 * 28
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_features)
        self.softplus = nn.Softplus()
        self.sigmoid = nn.Sigmoid()

    def forward(self, z, digit, is_handwritten):
        """Return the sigmoid-activated output for the concatenated inputs."""
        stacked = torch.cat((z, digit, is_handwritten), dim=1)
        activations = self.softplus(self.fc1(stacked))
        return self.sigmoid(self.fc2(activations))

class Encoder(nn.Module):
    """Network mapping ``(img, digit, is_handwritten)`` to Normal parameters.

    The inputs are concatenated along dim 1 and passed through one softplus
    hidden layer. Two separate linear heads then produce the location and
    the (exponentiated, hence positive) scale of the latent code.
    """

    def __init__(self, z_dim, hidden_dim):
        """Create the shared hidden layer and the two output heads.

        Args:
            z_dim: Width of the latent code.
            hidden_dim: Width of the hidden layer.
        """
        super().__init__()
        # Input width = 28*28 pixels + 10 digit slots + 1 handwritten flag.
        in_features = 28 * 28 + 10 + 1
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, z_dim)  # location head
        self.fc22 = nn.Linear(hidden_dim, z_dim)  # log-scale head
        self.softplus = nn.Softplus()

    def forward(self, img, digit, is_handwritten):
        """Return ``(z_loc, z_scale)`` for the given batch."""
        stacked = torch.cat((img, digit, is_handwritten), dim=1)
        activations = self.softplus(self.fc1(stacked))
        location = self.fc21(activations)
        scale = torch.exp(self.fc22(activations))
        return location, scale

Next, we implement the variational autoencoder. The `model` method implements the causal model. First it samples the latent variable Z, the digit variable, and the is_handwritten variable. These are passed to the decoder, which generates the image.

`training_model` extends `model` towards representing each image in the dataset. `training_guide` contains the encoder. The purpose of `training_guide` is to represent the approximating distribution during variational training.

In [None]:
class VAE(nn.Module):
    """Variational autoencoder pairing an ``Encoder`` with a ``Decoder``.

    ``model`` is the generative program: it samples ``Z``, ``digit`` and
    ``is_handwritten`` from fixed priors and decodes them into an image.
    ``training_model`` conditions that program on a batch of observations,
    and ``training_guide`` supplies the encoder-based approximate posterior
    over ``Z`` used during variational training.
    """

    def __init__(
        self,
        z_dim=50,
        hidden_dim=400,
        use_cuda=USE_CUDA,
    ):
        """Store the sizes and build the networks.

        Args:
            z_dim: Width of the latent code ``Z``.
            hidden_dim: Hidden-layer width of both networks.
            use_cuda: When true, the module is moved to the GPU.
        """
        super().__init__()
        self.use_cuda = use_cuda
        self.z_dim = z_dim
        self.hidden_dim = hidden_dim
        self.setup_networks()

    def setup_networks(self):
        """Instantiate the encoder and decoder, moving them to CUDA if asked."""
        self.encoder = Encoder(self.z_dim, self.hidden_dim)
        self.decoder = Decoder(self.z_dim, self.hidden_dim)
        if self.use_cuda:
            self.cuda()

    def _prior_options(self):
        """Return tensor-creation kwargs matching the decoder's device."""
        # Follow the decoder's parameters instead of a module-level constant,
        # so the priors always sit on the same device as the network that
        # consumes them, whatever `use_cuda` this instance was built with.
        device = next(self.decoder.parameters()).device
        return dict(dtype=torch.float32, device=device)

    def model(self, data_size=1):
        """Sample ``data_size`` images from the generative program.

        Sample sites: ``"Z"`` (standard Normal), ``"digit"`` (uniform
        one-hot categorical over 10 classes), ``"is_handwritten"``
        (Bernoulli with probability 1/2) and ``"img"`` (Bernoulli with the
        decoder output as its parameter).

        Returns:
            The ``(img, digit, is_handwritten)`` samples.
        """
        pyro.module("decoder", self.decoder)
        options = self._prior_options()
        z_loc = torch.zeros(data_size, self.z_dim, **options)
        z_scale = torch.ones(data_size, self.z_dim, **options)
        z = pyro.sample("Z", dist.Normal(z_loc, z_scale).to_event(1))
        p_digit = torch.ones(data_size, 10, **options)/10
        digit = pyro.sample(
            "digit",
            dist.OneHotCategorical(p_digit)
        )
        p_is_handwritten = torch.ones(data_size, 1, **options)/2
        is_handwritten = pyro.sample(
            "is_handwritten",
            dist.Bernoulli(p_is_handwritten).to_event(1)
        )
        img_param = self.decoder(z, digit, is_handwritten)
        img = pyro.sample("img", dist.Bernoulli(img_param).to_event(1))
        return img, digit, is_handwritten

    def training_model(self, img, digit, is_handwritten, batch_size):
        """Run ``model`` conditioned on an observed batch.

        The ``"digit"``, ``"is_handwritten"`` and ``"img"`` sites are fixed
        to the supplied values; only ``"Z"`` remains latent. The call is
        wrapped in a ``"data"`` plate of size ``batch_size``.
        """
        model_conditioned_on_data = pyro.condition(
            self.model,
            data={
                "digit": digit,
                "is_handwritten": is_handwritten,
                "img": img
            }
        )
        with pyro.plate("data", batch_size):
            img, digit, is_handwritten = model_conditioned_on_data(batch_size)
        return img, digit, is_handwritten

    def training_guide(self, img, digit, is_handwritten, batch_size):
        """Approximate posterior over ``"Z"`` given the observed batch.

        The encoder maps the observations to the location and scale of a
        Normal distribution, from which ``"Z"`` is sampled inside the same
        ``"data"`` plate used by ``training_model``.
        """
        pyro.module("encoder", self.encoder)
        with pyro.plate("data", batch_size):
            z_loc, z_scale = self.encoder(img, digit, is_handwritten)
            pyro.sample("Z", dist.Normal(z_loc, z_scale).to_event(1))

The following utility functions help us visualize progress during training.

In [None]:
def plot_image(img, title=None):
    """Display one image in a new greyscale figure.

    Args:
        img: Image to draw; it is moved to host memory with ``.cpu()``
            before being handed to matplotlib.
        title: Optional figure title; nothing is set when it is ``None``.
    """
    render_options = {"cmap": 'Greys_r', "interpolation": 'nearest'}
    plt.figure()
    plt.imshow(img.cpu(), **render_options)
    wants_title = title is not None
    if wants_title:
        plt.title(title)
    plt.show()

def reconstruct_img(vae, img, digit, is_handwritten, use_cuda=USE_CUDA):
    """Encode a single image and decode it again.

    The image and its labels are passed through ``vae.encoder``; a latent
    code is drawn from the resulting Normal distribution and fed, with the
    same labels, to ``vae.decoder``.

    Args:
        vae: Object exposing ``encoder`` and ``decoder`` callables.
        img: Image tensor; flattened to rows of ``28 * 28`` values.
        digit: Integer digit label, converted with ``torch.tensor`` and
            one-hot encoded over 10 classes.
        is_handwritten: Handwritten flag, converted with ``torch.tensor``
            and given a leading batch axis.
        use_cuda: When true, inputs are moved to the GPU first.

    Returns:
        The decoder output as a detached ``(28, 28)`` tensor.
    """
    img = img.reshape(-1, 28 * 28)
    digit = F.one_hot(torch.tensor(digit), 10)
    # Use the argument that was passed in, not a same-purpose global that
    # only exists while the training cell is running.
    is_handwritten = torch.tensor(is_handwritten).unsqueeze(0)
    if use_cuda:
        img, digit, is_handwritten = img.cuda(), digit.cuda(), is_handwritten.cuda()
    z_loc, z_scale = vae.encoder(img, digit, is_handwritten)
    z = dist.Normal(z_loc, z_scale).sample()
    img_expectation = vae.decoder(z, digit, is_handwritten)
    return img_expectation.squeeze().view(28, 28).detach()

def compare_images(img1, img2):
    """Show two images side by side, titled 'original' and 'reconstruction'.

    Both images are moved to host memory with ``.cpu()`` and drawn in
    greyscale with the axes hidden.
    """
    fig = plt.figure()
    panels = (
        (121, img1, 'original'),
        (122, img2, 'reconstruction'),
    )
    for position, picture, caption in panels:
        # add_subplot makes the new axes current, so the pyplot calls
        # below draw into it.
        fig.add_subplot(position)
        plt.imshow(picture.cpu(), cmap='Greys_r', interpolation='nearest')
        plt.axis('off')
        plt.title(caption)
    plt.show()

These additional utility functions help us select and reshape images, as well as generate new images.

In [None]:
def get_random_example(loader):
    """Pick one example uniformly at random from ``loader.dataset``.

    Returns:
        ``(img, digit, is_handwritten)`` with size-1 axes squeezed out of
        the image; the two labels are returned as stored.
    """
    examples = loader.dataset
    pick = np.random.randint(0, len(examples))
    picture, label, handwritten_flag = examples[pick]
    return picture.squeeze(), label, handwritten_flag

def reshape_data(img, digit, is_handwritten):
    """Put a batch into the flat layout the encoder and decoder consume.

    Args:
        img: Image batch; flattened to rows of ``28 * 28`` values.
        digit: Integer label tensor of shape ``(batch, 1)`` or ``(batch,)``.
        is_handwritten: Returned unchanged.

    Returns:
        ``(img, digit, is_handwritten)`` where ``img`` has shape
        ``(batch, 784)`` and ``digit`` is one-hot with shape ``(batch, 10)``.
    """
    # Squeeze only the label axis (dim 1). A bare squeeze() would also drop
    # the batch axis for a batch of one, e.g. a final partial batch.
    digit = F.one_hot(digit, 10).squeeze(1)
    img = img.reshape(-1, 28*28)
    return img, digit, is_handwritten

def generate_coded_data(vae, use_cuda=USE_CUDA):
    """Draw one latent code and one pair of labels, then decode them.

    ``z`` comes from a standard Normal of width ``vae.z_dim``, ``digit``
    from a uniform one-hot categorical over 10 classes, and
    ``is_handwritten`` from a Bernoulli with probability 1/2.

    Returns:
        ``(img, digit, is_handwritten)`` where ``img`` is the raw decoder
        output and the labels are still in their sampled encoding.
    """
    latent_shape = (1, vae.z_dim)
    latent_prior = dist.Normal(
        torch.zeros(*latent_shape), torch.ones(*latent_shape)
    ).to_event(1)
    z = latent_prior.sample()

    digit = dist.OneHotCategorical(torch.ones(1, 10)/10).sample()
    is_handwritten = dist.Bernoulli(torch.ones(1, 1)/2).sample()

    if use_cuda:
        z = z.cuda()
        digit = digit.cuda()
        is_handwritten = is_handwritten.cuda()

    decoded = vae.decoder(z, digit, is_handwritten)
    return decoded, digit, is_handwritten

def generate_data(vae, use_cuda=USE_CUDA):
    """Generate one image and decode its labels into plain indices.

    Returns:
        img: Detached ``(28, 28)`` tensor of decoder outputs.
        digit: Shape ``(1,)`` integer tensor holding the sampled digit.
        is_handwritten: Shape ``(1,)`` integer tensor, 1 when the sample
            was drawn as handwritten and 0 otherwise.
    """
    img, digit, is_handwritten = generate_coded_data(vae, use_cuda)
    img = img.squeeze().view(28, 28).detach()
    digit = torch.argmax(digit, 1)
    # The flag is a single 0/1 column, not a one-hot vector: argmax over a
    # width-1 axis is always 0, so read the sampled value itself.
    is_handwritten = is_handwritten.squeeze(1).long()
    return img, digit, is_handwritten

Finally, we run training. The training objective simultaneously trains the parameters of the encoder and the decoder. It focuses on minimizing reconstruction loss, meaning how much information is lost when an image is encoded and then decoded once again.

In [None]:
# Start from a clean parameter store so earlier runs do not leak in.
pyro.clear_param_store()
vae = VAE()

train_loader, test_loader = setup_dataloaders(batch_size=BATCH_SIZE, use_cuda=USE_CUDA)
train_size = len(train_loader.dataset)
test_size = len(test_loader.dataset)

svi_adam = pyro.optim.Adam({"lr": LEARNING_RATE})
svi = SVI(vae.training_model, vae.training_guide, svi_adam, loss=Trace_ELBO())
# Per-example average losses: one training entry per epoch, one test entry
# per evaluated epoch.
train_loss, test_loss = [], []

# The range includes NUM_EPOCHS itself.
for epoch in range(NUM_EPOCHS + 1):
    epoch_loss_train = 0
    for img, digit, is_handwritten in train_loader:
        batch_size = img.shape[0]
        if USE_CUDA:
            img, digit, is_handwritten = img.cuda(), digit.cuda(), is_handwritten.cuda()
        img, digit, is_handwritten = reshape_data(img, digit, is_handwritten)
        # svi.step takes one optimisation step and returns the batch loss.
        epoch_loss_train += svi.step(img, digit, is_handwritten, batch_size)
    epoch_loss_train = epoch_loss_train / train_size
    print(f"Epoch: {epoch} average training loss: {epoch_loss_train}")
    train_loss.append(epoch_loss_train)

    if epoch % TEST_FREQUENCY == 0:
        epoch_loss_test = 0
        for img, digit, is_handwritten in test_loader:
            batch_size = img.shape[0]
            if USE_CUDA:
                img, digit, is_handwritten = img.cuda(), digit.cuda(), is_handwritten.cuda()
            img, digit, is_handwritten = reshape_data(img, digit, is_handwritten)
            # evaluate_loss computes the loss without updating parameters.
            epoch_loss_test += svi.evaluate_loss(img, digit, is_handwritten, batch_size)
        epoch_loss_test = epoch_loss_test / test_size
        print(f"Epoch: {epoch} average test loss: {epoch_loss_test}")
        # Keep the history; the list was previously created but never filled.
        test_loss.append(epoch_loss_test)

        print("Comparing a random test image to its reconstruction:")
        img_rng, digit_rng, is_handwritten_rng = get_random_example(test_loader)
        img_recon = reconstruct_img(vae, img_rng, digit_rng, is_handwritten_rng)
        compare_images(img_rng, img_recon)

        print("Generate a random image from the model:")
        img_gen, digit_gen, is_handwritten_gen = generate_data(vae)
        plot_image(img_gen, "Generated Image")
        print("Intended digit: ", int(digit_gen))
        print("Intended as handwritten: ", bool(is_handwritten_gen == 1))

# Plot the training curve. The loss is negated, so the curve shows the
# per-example ELBO estimate; the axis label says so.
plt.plot(range(len(train_loss)), [-x for x in train_loss])
plt.xlabel('Epoch')
plt.ylabel('ELBO (negative training loss)')
plt.show()

We can continue to use `generate_data` to generate from the model once we've trained it. Finally, we can save the resulting model.

In [None]:

#torch.save(vae.state_dict(), 'mnist_tmnist_weights_March11.pt')