# Training a Neural Causal Model

As Robert Osazuwa Ness details in his book [Causal AI](https://www.manning.com/books/causal-ai):

> We regularly work with functions in causal modeling and inference, and sometimes it makes sense to approximate them, *so long as the approximations preserve the causal information we care about.* ... [I]n this section, we’ll do this mapping between a node and its parents with the variational autoencoder (VAE) framework. We’ll train two deep neural nets in the VAE, one of which *maps parent cause variables to a distribution of the outcome variable*, and another that *maps the outcome variable to a distribution of the cause variables.*

We will use this framework to build a neural causal model of an image using the combined dataset of [MNIST](https://en.wikipedia.org/wiki/MNIST_database) and [TypefaceMNIST](https://paperswithcode.com/dataset/typography-mnist), in which a binary label represents whether the digit is handwritten (i.e., 1) or typed (i.e., 0).

In [None]:
!pip install "pyro-ppl"
!pip install "torch"
!pip install "torchvision"
!pip install "numpy==1.26.4"

## 1. Downloading and Preprocessing Data

In [None]:
import torch
# Use the GPU only when one is actually available, so the notebook also runs
# on CPU-only machines instead of failing on the first CUDA call.
USE_CUDA = torch.cuda.is_available()
# Device on which tensors are allocated.
DEVICE_TYPE = torch.device("cuda" if USE_CUDA else "cpu")

In [None]:
# class to combine the datasets
from torch.utils.data import Dataset
import pandas as pd
import numpy as np
from torchvision import transforms

class CombinedDataset(Dataset):
    """Dataset over a CSV of combined MNIST / TypefaceMNIST rows.

    Each row is read positionally: column 1 is the is-handwritten flag,
    column 2 is the digit label, and columns 3 onward are the pixel values,
    which are divided by 255 and reshaped to 28 x 28.
    """

    def __init__(self, csv_file):
        """Read the whole CSV at ``csv_file`` into a pandas DataFrame."""
        self.dataset = pd.read_csv(csv_file)

    def __len__(self):
        """Return the number of rows in the CSV."""
        return self.dataset.shape[0]

    def __getitem__(self, idx):
        """Return ``(image, digit, is_handwritten)`` for row ``idx``.

        ``image`` is whatever ``transforms.ToTensor()`` produces for the
        28 x 28 float32 array (presumably a 1 x 28 x 28 tensor -- confirm
        against torchvision); ``digit`` is a length-1 int array and
        ``is_handwritten`` a length-1 float32 array.
        """
        # Pixels: scale to [0, 1] and restore the 2-D image layout.
        pixel_row = self.dataset.iloc[idx, 3:]
        pixel_grid = (np.array(pixel_row, dtype="float32") / 255).reshape(28, 28)
        image_tensor = transforms.ToTensor()(pixel_grid)

        # Labels: wrapped in length-1 arrays.
        digit_label = np.array([self.dataset.iloc[idx, 2]], dtype="int")
        handwritten_flag = np.array([self.dataset.iloc[idx, 1]], dtype='float32')

        return image_tensor, digit_label, handwritten_flag


In [None]:
# download, split, load data
from torch.utils.data import DataLoader
from torch.utils.data import random_split
def setup_dataloaders(batch_size=64, use_cuda=USE_CUDA):
    """Download the combined MNIST/TMNIST CSV and wrap it in two DataLoaders.

    The data is split 80% / 20% into train and test subsets using a generator
    seeded with 42, so the split is the same on every call.

    Args:
        batch_size: number of samples per batch for both loaders.
        use_cuda: forwarded to the loaders as ``pin_memory``.

    Returns:
        ``(train_loader, test_loader)``; both loaders shuffle their subset.
    """
    full_data = CombinedDataset(
        csv_file="https://raw.githubusercontent.com/altdeep/causalML/master/datasets/combined_mnist_tmnist_data.csv"
    )

    total = len(full_data)
    n_train = int(0.8 * total)
    n_test = total - n_train
    subsets = random_split(
        full_data,
        [n_train, n_test],
        generator=torch.Generator().manual_seed(42),
    )

    # Both loaders are configured identically; only the subset differs.
    loader_options = dict(
        batch_size=batch_size,
        shuffle=True,
        num_workers=1,
        pin_memory=use_cuda,
    )
    train_loader, test_loader = [
        DataLoader(subset, **loader_options) for subset in subsets
    ]
    return train_loader, test_loader

## 2. Creating the Encoder and Decoder

Similar to the traditional [Variational AutoEncoder (VAE)](https://en.wikipedia.org/wiki/Variational_autoencoder) used in machine learning, the causal form is also composed of an encoder and a decoder, but it is used to model the *causal DAG of generating the image* (i.e., the data-generating process (DGP) that generates the MNIST image). To that end, we model the causal drivers that determine the image we see. For this example, the main drivers of the MNIST images are 1) the label/type of *digit*, 2) the label of whether it is *handwritten*, and 3) (and in my opinion the most important) the latent causes as represented by the variable Z. As Robert Osazuwa Ness details in his book [Causal AI](https://www.manning.com/books/causal-ai):

>  Z appears as a new parent in the causal DAG...we view *digit* and *is-handwritten* as causal drivers of what we see in the image. Yet there are other elements of the image (e.g., the stroke thickness of a handwritten character, or the font of a typed character) that are also causes of what we see in the image. *We’ll think of Z as a continuous latent stand-in for all of these other causes of the image that we are not explicitly modeling...it is important to remember that the representation we learn for Z is a stand-in for latent causes and is not the same as learning the actual latent causes.*

Accordingly, the encoder's job is to encode the observed variables (the image, the *is-handwritten* label, and the *digit* label) into a latent variable Z, which, when combined with the *digit* and *is-handwritten* labels and fed into the decoder, generates the image we see.


In [None]:
# creating the Decoder
from torch import nn

class Decoder(nn.Module):
    """Decoder network: maps (Z, digit, is_handwritten) to image parameters.

    A single softplus hidden layer is followed by a sigmoid output layer with
    28 * 28 units, so every output value lies in [0, 1].
    """

    def __init__(self, z_dim, hidden_dim):
        """Create the layers.

        Args:
            z_dim: width of the latent vector Z.
            hidden_dim: number of units in the hidden layer.
        """
        super().__init__()
        n_pixels = 28 * 28
        n_digit_classes = 10
        n_handwritten_flags = 1
        self.softplus = nn.Softplus()
        self.sigmoid = nn.Sigmoid()
        # The first layer consumes Z, digit and is_handwritten side by side.
        self.fc1 = nn.Linear(
            z_dim + n_digit_classes + n_handwritten_flags, hidden_dim
        )
        self.fc2 = nn.Linear(hidden_dim, n_pixels)

    def forward(self, z, digit, is_handwritten):
        """Return the per-pixel parameters for the given causes.

        The three inputs are concatenated along dim 1, so they must agree on
        every other dimension and their widths must total ``z_dim + 10 + 1``.
        """
        causes = torch.cat((z, digit, is_handwritten), dim=1)
        activations = self.softplus(self.fc1(causes))
        return self.sigmoid(self.fc2(activations))

# creating the Encoder
class Encoder(nn.Module):
    """Encoder network: maps (img, digit, is_handwritten) to parameters of Z.

    A shared softplus hidden layer feeds two linear heads, one producing the
    location and one producing the (exponentiated, hence positive) scale.
    """

    def __init__(self, z_dim, hidden_dim):
        """Create the layers.

        Args:
            z_dim: width of the latent vector Z (output width of both heads).
            hidden_dim: number of units in the shared hidden layer.
        """
        super().__init__()
        n_pixels = 28 * 28
        n_digit_classes = 10
        n_handwritten_flags = 1
        self.softplus = nn.Softplus()
        self.fc1 = nn.Linear(
            n_pixels + n_digit_classes + n_handwritten_flags, hidden_dim
        )
        # fc21 -> location head, fc22 -> log-scale head.
        self.fc21 = nn.Linear(hidden_dim, z_dim)
        self.fc22 = nn.Linear(hidden_dim, z_dim)

    def forward(self, img, digit, is_handwritten):
        """Return ``(z_loc, z_scale)`` for the given observations.

        The three inputs are concatenated along dim 1, so their widths must
        total ``28 * 28 + 10 + 1``.
        """
        evidence = torch.cat((img, digit, is_handwritten), dim=1)
        activations = self.softplus(self.fc1(evidence))
        loc = self.fc21(activations)
        # Exponentiate so the scale is strictly positive.
        scale = self.fc22(activations).exp()
        return loc, scale

## 3. Pyro Functions Used to Model the Causal Variational AutoEncoder (VAE)

As detailed above, the decoder models the causal [Markov kernel](https://math.stackexchange.com/questions/4594349/question-about-the-definition-of-markov-kernel), that is, the conditional probability distribution $P(X \mid \text{digit}, \text{is-handwritten}, Z)$, where $X$ represents the image. The latent variable Z, the label *digit*, and the label *is-handwritten* are sampled from canonical distributions. In this case, Z is sampled from a Normal distribution, the label *digit* is sampled from a Categorical distribution, and the label *is-handwritten* is sampled from a Bernoulli distribution.

Additionally, as the book details, since we are modeling the joint probability distribution (i.e., $P(X, \text{digit}, \text{is-handwritten}, Z)$), the random variable X, which represents the pixel values of the image, needs to be accounted for too. Even though the pixel values are not technically binary outcomes, since they can take values from 0 to 255, we model them using the Bernoulli distribution due to ease of use. This helper method is named **model** in the VAE.

To apply the model to the images in the training data, the helper method **training_model** is used to condition the model on the data (i.e., the labels and images). Furthermore, because we want to get good samples for the latent variable Z *(since it represents the image)*, we implement a guide function that uses the encoder's output and canonical distributions to sample good values for Z. This helper method is named **training_guide** in the VAE.




In [None]:
# Creating VAE
import pyro
import pyro.distributions as dist

# Switch off validation in pyro.distributions for everything that follows.
# NOTE(review): presumably needed because the images are real-valued in [0, 1]
# yet are scored under a Bernoulli likelihood, which validation would reject
# -- confirm before re-enabling.
dist.enable_validation(False)
class VAE(nn.Module):
    """Causal variational autoencoder built from ``Encoder`` and ``Decoder``.

    ``model`` is the generative program over Z, digit, is_handwritten and
    img; ``training_model`` conditions it on a batch of data; and
    ``training_guide`` proposes Z from the encoder's output.
    """

    def __init__(
        self,
        z_dim=50,
        hidden_dim=400,
        use_cuda=USE_CUDA,
    ):
        """Store the hyper-parameters and build the networks.

        Args:
            z_dim: width of the latent vector Z.
            hidden_dim: hidden-layer width of both networks.
            use_cuda: when true, the module is moved to the GPU.
        """
        super().__init__()
        self.use_cuda = use_cuda
        self.z_dim = z_dim
        self.hidden_dim = hidden_dim
        self.setup_networks()

    def setup_networks(self):
        """Instantiate the encoder and decoder, moving them to CUDA if asked."""
        self.encoder = Encoder(self.z_dim, self.hidden_dim)
        self.decoder = Decoder(self.z_dim, self.hidden_dim)
        if self.use_cuda:
            self.cuda()

    def _tensor_options(self):
        """Return dtype/device keyword arguments for tensors made in ``model``.

        The device is read from the decoder's own parameters rather than from
        a module-level constant, so the prior tensors always live on the same
        device as the network that consumes them (for example when
        ``use_cuda=False`` is passed on a machine that has a GPU).
        """
        device = next(self.decoder.parameters()).device
        return dict(dtype=torch.float32, device=device)

    def model(self, data_size=1):
        """Sample ``(img, digit, is_handwritten)`` from the generative model.

        Z is drawn from a standard Normal, digit from a uniform one-hot
        Categorical over 10 classes, and is_handwritten from a Bernoulli(0.5).
        The decoder maps those three to the Bernoulli parameters of img.

        Args:
            data_size: number of rows (leading dimension) to sample.

        Returns:
            The sampled (or conditioned) ``img``, ``digit`` and
            ``is_handwritten`` values.
        """
        pyro.module("decoder", self.decoder)
        options = self._tensor_options()

        z_loc = torch.zeros(data_size, self.z_dim, **options)
        z_scale = torch.ones(data_size, self.z_dim, **options)
        z = pyro.sample("Z", dist.Normal(z_loc, z_scale).to_event(1))

        p_digit = torch.ones(data_size, 10, **options) / 10
        digit = pyro.sample(
            "digit",
            dist.OneHotCategorical(p_digit),
        )

        p_is_handwritten = torch.ones(data_size, 1, **options) / 2
        is_handwritten = pyro.sample(
            "is_handwritten",
            dist.Bernoulli(p_is_handwritten).to_event(1),
        )

        img_param = self.decoder(z, digit, is_handwritten)
        img = pyro.sample("img", dist.Bernoulli(img_param).to_event(1))
        return img, digit, is_handwritten

    def training_model(self, img, digit, is_handwritten, batch_size):
        """Run ``model`` with its observed sites fixed to a batch of data.

        Args:
            img, digit, is_handwritten: values used for the "img", "digit"
                and "is_handwritten" sample sites.
            batch_size: size of the "data" plate and of the model's leading
                dimension.

        Returns:
            The ``(img, digit, is_handwritten)`` tuple returned by the
            conditioned model.
        """
        conditioned_on_data = pyro.condition(
            self.model,
            data={
                "digit": digit,
                "is_handwritten": is_handwritten,
                "img": img,
            },
        )
        with pyro.plate("data", batch_size):
            img, digit, is_handwritten = conditioned_on_data(batch_size)
        return img, digit, is_handwritten

    def training_guide(self, img, digit, is_handwritten, batch_size):
        """Variational guide: sample Z from the encoder's Normal.

        Takes the same arguments as ``training_model``. Its only sample site
        is "Z", the one site ``training_model`` leaves unobserved.
        """
        pyro.module("encoder", self.encoder)
        with pyro.plate("data", batch_size):
            z_loc, z_scale = self.encoder(img, digit, is_handwritten)
            pyro.sample("Z", dist.Normal(z_loc, z_scale).to_event(1))

## 4. Training the Causal Variational Autoencoder (VAE)