# Chapter 5 - Connecting Causality and Deep Learning

This notebook is a code companion to chapter 5 of the book [Causal AI](https://www.manning.com/books/causal-ai) by [Robert Osazuwa Ness](https://www.linkedin.com/in/osazuwa/).

<a href="https://colab.research.google.com/github/altdeep/causalML/blob/master/book/chapter%205/chapter_5_Connecting_Causality_and_Deep_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook was written in Google Colab using Python version 3.10.12. The versions of the main libraries include:
* pyro version 1.8.4
* torch version 2.2.1
* pandas version 2.0.3
* torchvision version 0.18.0+cu121




Pgmpy allows us to fit conventional Bayesian networks on a causal DAG. However, with modern deep probabilistic machine learning frameworks like Pyro, we can build more nuanced and powerful causal models. In this tutorial, we fit a variational autoencoder on a causal DAG that represents a dataset mixing handwritten MNIST digits and typed T-MNIST images.

![TMNIST-MNIST](https://github.com/altdeep/causalML/blob/master/book/chapter%205/images/MNIST-TMNIST.png?raw=1)

In [1]:
# !pip install pyro-ppl==1.8.4
# !pip install torchvision

## Listing 5.1: Setup for GPU training

The code will run faster on a GPU, so we use CUDA when it is available.

In [2]:
import torch    #A
USE_CUDA = torch.cuda.is_available()    #A
DEVICE_TYPE = torch.device("cuda" if USE_CUDA else "cpu")    #A
#A Use CUDA if it is available.

## Listing 5.2: Combining the data

First, we create a Dataset object that will combine our two datasets.

In [3]:
from torch.utils.data import Dataset

import numpy as np
import pandas as pd
from torchvision import transforms

class CombinedDataset(Dataset):    #A
    def __init__(self, csv_file):
        self.dataset = pd.read_csv(csv_file)
        print(self.dataset.shape)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        images = self.dataset.iloc[idx, 3:]    #B
        images = np.array(images, dtype='float32')/255.  #B
        images = images.reshape(28, 28)    #B
        transform = transforms.ToTensor()    #B
        images = transform(images)    #B
        digits = self.dataset.iloc[idx, 2]    #C
        digits = np.array([digits], dtype='int')    #C
        is_handwritten = self.dataset.iloc[idx, 1]    #D
        is_handwritten = np.array([is_handwritten], dtype='float32')    #D
        return images, digits, is_handwritten    #E

#A This class loads and processes a dataset that combines the MNIST and Typeface MNIST. The output is a torch.utils.data.Dataset object.
#B Load, normalize, and reshape the image into a 28x28 pixel grid.
#C Get and process the digit labels, 0-9.
#D 1 for handwritten digits (MNIST), 0 for “typed” digits (TMNIST).
#E Return tuple of the image, the digit label, and the is_handwritten label.

## Listing 5.3: Downloading, splitting and loading the data

Next, we'll download the data and create the combined dataset.

In [4]:
from torch.utils.data import DataLoader
from torch.utils.data import random_split

def setup_dataloaders(batch_size=64, use_cuda=USE_CUDA):    #A
    combined_dataset = CombinedDataset(
"https://raw.githubusercontent.com/altdeep/causalML/master/datasets/combined_mnist_tmnist_data.csv"
    )
    n = len(combined_dataset)    #B
    train_size = int(0.8 * n)    #B
    test_size = n - train_size    #B
    train_dataset, test_dataset = random_split(    #B
        combined_dataset,    #B
        [train_size, test_size],    #B
        generator=torch.Generator().manual_seed(42)    #B
    )    #B
    kwargs = {'num_workers': 1, 'pin_memory': use_cuda}
    train_loader = DataLoader(    #C
        train_dataset,    #C
        batch_size=batch_size,    #C
        shuffle=True,    #C
        **kwargs    #C
    )    #C
    test_loader = DataLoader(    #C
        test_dataset,    #C
        batch_size=batch_size,    #C
        shuffle=True,    #C
        **kwargs    #C
    )    #C
    return train_loader, test_loader
#A Set up data loaders that load the data and split it into training and test sets.
#B Allot 80% of the data to training data, the remaining 20% to test data.
#C Create training and test loaders.
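The 80/20 split arithmetic can be sketched without PyTorch. `split_indices` is a hypothetical stdlib helper, not part of torch: it uses Python's `random` module, so it will not reproduce the exact permutation that `random_split` draws from `torch.Generator().manual_seed(42)`. It only shows that a seeded shuffle followed by a cut gives a reproducible, disjoint split.

```python
import random

def split_indices(n, train_frac=0.8, seed=42):
    """Hypothetical helper: shuffle 0..n-1 with a fixed seed, then cut into train/test."""
    train_size = int(train_frac * n)       # same arithmetic as setup_dataloaders
    indices = list(range(n))
    random.Random(seed).shuffle(indices)   # seeded shuffle, so the split is reproducible
    return indices[:train_size], indices[train_size:]

train_idx, test_idx = split_indices(50_000)
print(len(train_idx), len(test_idx))  # 40000 10000
```

With the 50,000 rows loaded below, this gives 40,000 training and 10,000 test examples.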

In [5]:
train_loader, test_loader = setup_dataloaders()

(50000, 787)


Column 0 is the index of the data point in the dataset, column 1 is `is_handwritten`, column 2 is the digit label, and columns 3 onward (784 columns, one per pixel of a 28x28 image) hold the pixel values.
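To make the column layout concrete, here is a stdlib sketch of what `CombinedDataset.__getitem__` does with one row. `parse_row` is a hypothetical helper, and the row is synthetic (index 0, handwritten, digit 7, every pixel 255); it assumes raw pixel intensities lie in 0-255, which is why the dataset divides by 255.

```python
import csv
import io

N_PIXELS = 28 * 28  # 784 pixel columns follow the three metadata columns

def parse_row(row):
    """Split one CSV row into (image, digit, is_handwritten) using the layout above."""
    is_handwritten = float(row[1])                 # column 1: 1 = MNIST, 0 = TMNIST
    digit = int(row[2])                            # column 2: digit label 0-9
    pixels = [float(v) / 255.0 for v in row[3:]]   # columns 3+: scaled to [0, 1]
    assert len(pixels) == N_PIXELS
    # Reshape the flat pixel list into 28 rows of 28 pixels
    image = [pixels[r * 28:(r + 1) * 28] for r in range(28)]
    return image, digit, is_handwritten

# Synthetic row for illustration only (a real file also has a header row)
fake_csv = ",".join(["0", "1", "7"] + ["255"] * N_PIXELS) + "\n"
row = next(csv.reader(io.StringIO(fake_csv)))
image, digit, is_hw = parse_row(row)
print(len(image), len(image[0]), digit, is_hw, image[0][0])  # 28 28 7 1.0 1.0
```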

## Listing 5.4: Implement the decoder

First, we specify a decoder. The decoder maps the latent variable Z, a variable representing the value of the digit, and a binary variable representing whether the digit is handwritten to a vector of Bernoulli parameters, one per image pixel.

In [11]:
# from torch import nn

# class Decoder(nn.Module):    #A
#     def __init__(self, z_dim, hidden_dim):
#         super().__init__()
#         img_dim = 28 * 28    #B
#         digit_dim = 10    #C
#         is_handwritten_dim = 1    #D
#         self.softplus = nn.Softplus()    #E
#         self.sigmoid = nn.Sigmoid()    #E
#         encoding_dim = z_dim + digit_dim + is_handwritten_dim    #F
#         self.fc1 = nn.Linear(encoding_dim, hidden_dim)    #F
#         self.fc2 = nn.Linear(hidden_dim, img_dim)    #G

#     def forward(self, z, digit, is_handwritten):    #H
#         input = torch.cat([z, digit, is_handwritten], dim=1)   #I
#         hidden = self.softplus(self.fc1(input))    #J
#         img_param = self.sigmoid(self.fc2(hidden))    #K
#         return img_param
# #A The decoder method of a VAE class.
# #B Image is 28 by 28 pixels
# #C Digit is one-hot encoded digits 0-9, i.e., a vector of length 10.
# #D An indicator for if the digit is handwritten that has size 1
# #E The softplus and sigmoid are nonlinear transforms (activation functions) used in mapping between layers.
# #F fc1 is a linear function that maps Z vector, the digit, and the is_handwritten to a linear out, which is passed through a softplus activation function to create a "hidden layer" - a vector whose length is given by hidden_layer.
# #G The fc2 linearly maps the hidden layer to an output passed to a sigmoid function. The resulting value is a value between 0 and 1.
# #H Define the forward computation from the latent Z variable value to a generated X variable value.
# #I First combine Z and the labels.
# #J Then compute the hidden layer.
# #K Finally, pass the hidden layer to a linear transform, then to a sigmoid transform to output a parameter vector of length 784. Each element of the vector corresponds to a Bernoulli parameter value for an image pixel.


In [None]:
from torch import nn

# This is the diffusion reverse model that predicts the conditional mean of the noise epsilon_t at any time step t given the noisy image x_t and the causal conditioning variables.
# The output is a vector of size img_dim with real values that directly denote the predicted noise that was added during the forward process - not a distribution.

class Decoder(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        img_dim = 28 * 28
        digit_dim = 10
        is_handwritten_dim = 1
        t_dim = 1

        self.softplus = nn.Softplus()

        encoding_dim = img_dim + digit_dim + is_handwritten_dim + t_dim
        self.fc1 = nn.Linear(encoding_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, img_dim)

    def forward(self, x_t, digit, is_handwritten, t):
        x_t = x_t.view(x_t.size(0), -1)
        t = t.unsqueeze(1).float()
        input = torch.cat([x_t, digit, is_handwritten, t], dim=1)
        hidden = self.softplus(self.fc1(input))
        eps_hat = self.fc2(hidden)
        return eps_hat
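Both decoders concatenate their inputs into one flat vector before the first linear layer. This plain-Python sketch (the hypothetical `one_hot` and `decoder_input` stand in for `F.one_hot` and `torch.cat`) rebuilds the diffusion decoder's input in the order used by `forward` and shows where its width of 784 + 10 + 1 + 1 = 796 comes from. The commented-out VAE decoder uses the same pattern with `z_dim + 10 + 1` inputs (61 for `z_dim=50`). As in the code above, t enters as the raw step index, without rescaling.

```python
def one_hot(digit, num_classes=10):
    """Plain-Python stand-in for F.one_hot on a single label."""
    vec = [0.0] * num_classes
    vec[digit] = 1.0
    return vec

def decoder_input(x_t, digit, is_handwritten, t):
    """Concatenate [x_t, one-hot digit, is_handwritten, t], as Decoder.forward does."""
    return list(x_t) + one_hot(digit) + [float(is_handwritten)] + [float(t)]

x_t = [0.0] * (28 * 28)  # a flattened noisy image
vec = decoder_input(x_t, digit=3, is_handwritten=1, t=500)
print(len(vec))  # 796
```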

## Forward Process of Diffusion

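This section implements the forward diffusion process. Given the original image x_0, a timestep t, and Gaussian noise eps, `q_sample` returns x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, where alpha_bar_t is the cumulative product of alpha_s = 1 - beta_s for s up to t, under a linear beta schedule with T steps.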
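As a dependency-free sketch of the same schedule (same endpoints and T as the cell below, with lists in place of `torch.linspace` and `torch.cumprod`), we can check how alpha_bar_t decays: it starts at 0.9999 and falls below 1e-4 by the last step, so x_T is almost pure noise. `q_sample_scalar` is a hypothetical one-pixel version of `q_sample`.

```python
import math
import random

T = 1000
beta_start, beta_end = 1e-4, 0.02
# Linear schedule with the same endpoints as torch.linspace(1e-4, 0.02, T)
betas = [beta_start + (beta_end - beta_start) * i / (T - 1) for i in range(T)]
alphas = [1.0 - b for b in betas]

# alpha_bar_t = product of alpha_s for s <= t (what torch.cumprod computes)
alpha_bars = []
running = 1.0
for a in alphas:
    running *= a
    alpha_bars.append(running)

def q_sample_scalar(x0, t, eps):
    """x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps for a single pixel."""
    return math.sqrt(alpha_bars[t]) * x0 + math.sqrt(1.0 - alpha_bars[t]) * eps

rng = random.Random(0)
eps = rng.gauss(0.0, 1.0)
for t in (0, 250, 500, 999):
    # A white pixel (x0 = 1.0) drifts toward the noise value as t grows
    print(t, round(alpha_bars[t], 5), round(q_sample_scalar(1.0, t, eps), 3))
```

Because the two coefficients satisfy alpha_bar_t + (1 - alpha_bar_t) = 1, unit-variance inputs keep unit variance at every step.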

In [None]:
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)

def q_sample(x0, t, eps):
    sqrt_ab = alpha_bars[t].sqrt().unsqueeze(1)
    sqrt_1_ab = (1 - alpha_bars[t]).sqrt().unsqueeze(1)
    return sqrt_ab * x0 + sqrt_1_ab * eps

## Listing 5.5: The causal model

The `model` method implements the causal model. First it samples the latent variable Z, the digit variable, and the is_handwritten variable. These are passed to the decoder, which maps them to a Bernoulli parameter for each pixel; the image is then sampled from those Bernoulli distributions.

In [12]:
import pyro
import pyro.distributions as dist

dist.enable_validation(False)    #A
def model(self, data_size=1):    #B
    pyro.module("decoder", self.decoder)    #B
    options = dict(dtype=torch.float32, device=DEVICE_TYPE)
    z_loc = torch.zeros(data_size, self.z_dim, **options)    #C
    z_scale = torch.ones(data_size, self.z_dim, **options)    #C
    z = pyro.sample("Z", dist.Normal(z_loc, z_scale).to_event(1))    #C - prior belief: latent styles are normally distributed around the origin
    p_digit = torch.ones(data_size, 10, **options)/10    #D - categorical RV - prior probability distribution for the 10 digits
    digit = pyro.sample(    #D
        "digit",    #D
        dist.OneHotCategorical(p_digit)    #D
    )    #D
    p_is_handwritten = torch.ones(data_size, 1, **options)/2    #E - bernoulli RV - prior probability for MNIST vs TMNIST
    is_handwritten = pyro.sample(    #E
        "is_handwritten",    #E
        dist.Bernoulli(p_is_handwritten).to_event(1)    #E
    )    #E
    img_param = self.decoder(z, digit, is_handwritten)    #F - Each element of the vector corresponds to a Bernoulli parameter value for an image pixel.
    img = pyro.sample("img", dist.Bernoulli(img_param).to_event(1))  #G - samples to get a realization from the img_param distribution
    return img, digit, is_handwritten
#A Disabling distribution validation lets Pyro calculate log-likelihoods for pixels even though the pixel values are not binary.
#B The model of a single image. Within the method we register the decoder, a PyTorch module, with Pyro. This lets Pyro know about the parameters inside the decoder network.
#C We model the joint probability of Z, digit, and is_handwritten by sampling each from canonical distributions. We sample Z from a multivariate normal with location parameter z_loc (all zeros) and scale parameter z_scale (all ones).
#D We also sample the digit from a one-hot categorical distribution. Equal probability is assigned to each digit.
#E We similarly sample the is_handwritten variable from a Bernoulli.
#F The decoder maps digit, is_handwritten, and Z to a probability parameter vector.
#G That parameter vector is passed to the Bernoulli distribution, which models the pixel values in the data. The pixels are not technically Bernoulli binary variables, but we'll relax this assumption.
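Annotations #A and #G rely on the Bernoulli log-likelihood formula still making sense for gray pixels. This sketch evaluates x*log(p) + (1-x)*log(1-p) (the negative binary cross-entropy) for a non-binary x and shows that, for a fixed pixel value, it peaks at p = x, which is why the decoder learns to output pixel intensities. It is a sketch of the formula, not of Pyro's internals, and for x in (0, 1) it is an unnormalized relaxation rather than a proper density. The helper names are hypothetical.

```python
import math

def bernoulli_log_prob(x, p, eps=1e-7):
    """x*log(p) + (1-x)*log(1-p), evaluated even when x is not 0 or 1."""
    p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
    return x * math.log(p) + (1.0 - x) * math.log(1.0 - p)

def image_log_prob(pixels, params):
    """Sum over pixels, mirroring .to_event(1), which treats all pixels as one event."""
    return sum(bernoulli_log_prob(x, p) for x, p in zip(pixels, params))

x = 0.3  # a gray, non-binary pixel value
candidates = [i / 10 for i in range(1, 10)]
best_p = max(candidates, key=lambda p: bernoulli_log_prob(x, p))
print(best_p)  # 0.3
```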



In [None]:
import pyro
import pyro.distributions as dist

def model(self, x0, data_size=1):
    pyro.module("decoder", self.decoder)
    options = dict(dtype=torch.float32, device=DEVICE_TYPE)
    
    # sample diffusion time
    batched_probs = torch.ones(data_size, T, device=x0.device) / T
    t = pyro.sample("t", dist.Categorical(batched_probs))

    # sample noise
    eps = pyro.sample("eps", dist.Normal(0, 1).expand(x0.shape).to_event(1))
    
    p_digit = torch.ones(data_size, 10, **options)/10    #D - categorical RV - prior probability distribution for the 10 digits
    digit = pyro.sample(    #D
        "digit",    #D
        dist.OneHotCategorical(p_digit)    #D
    )    #D
    p_is_handwritten = torch.ones(data_size, 1, **options)/2    #E - bernoulli RV - prior probability for MNIST vs TMNIST
    is_handwritten = pyro.sample(    #E
        "is_handwritten",    #E
        dist.Bernoulli(p_is_handwritten).to_event(1)    #E
    )    #E

    # forward diffusion (deterministic)
    x_t = q_sample(x0, t, eps)

    # predict noise
    eps_hat = self.decoder(x_t, digit, is_handwritten, t)

    eps_realization = pyro.sample("obs_eps", dist.Normal(eps_hat, 1.0).to_event(1), obs=eps)

    return x_t, eps_realization, digit, is_handwritten

## Listing 5.6: Method for applying model to N images in data

`training_model` extends `model` to every image in a batch: it conditions `model` on the observed data and wraps it in a plate over the batch.

In [13]:
def training_model(self, img, digit, is_handwritten, batch_size):    #A
    conditioned_on_data = pyro.condition(    #B
        self.model,
        data={
            "digit": digit,
            "is_handwritten": is_handwritten,
            "img": img
        }
    )
    with pyro.plate("data", batch_size):    #C
        img, digit, is_handwritten = conditioned_on_data(batch_size)
    return img, digit, is_handwritten
#A The model represents the data generating process for one image. The training_model applies that model to the N images in the training data.
#B Now we condition the model on the evidence in the training data.
#C This context manager is the N-size plate in Figure 5.9 that represents the repeated IID examples in the data; here N is the batch size. It works like a for loop iterating over each data unit in the batch.

In [None]:
def training_model(self, img, digit, is_handwritten, T, batch_size):
    # Condition the model on the labels provided by the dataset
    conditioned_on_data = pyro.condition(
        self.model,
        data={
            "digit": digit,
            "is_handwritten": is_handwritten
        }
    )
    
    # We wrap everything in the data plate
    with pyro.plate("data", batch_size):
        # We pass 'img' (x0) into the model so it can be diffused
        x_t, noise_t, digit, is_handwritten = conditioned_on_data(x0=img, data_size=batch_size)

    return x_t, noise_t, digit, is_handwritten

## Listing 5.7: Implement the encoder

The encoder takes an image, the digit, and whether the digit is handwritten, and infers the parameters of a Normal distribution over the latent representation Z.

In [14]:
class Encoder(nn.Module):    #A
    def __init__(self, z_dim, hidden_dim):
        super().__init__()
        img_dim = 28 * 28    #B
        digit_dim = 10    #C
        is_handwritten_dim = 1
        self.softplus = nn.Softplus()    #D
        input_dim = img_dim + digit_dim + is_handwritten_dim    #E
        self.fc1 = nn.Linear(input_dim, hidden_dim)    #E
        self.fc21 = nn.Linear(hidden_dim, z_dim)    #F
        self.fc22 = nn.Linear(hidden_dim, z_dim)    #F

    def forward(self, img, digit, is_handwritten):    #G
        input = torch.cat([img, digit, is_handwritten], dim=1)    #H
        hidden = self.softplus(self.fc1(input))    #I
        z_loc = self.fc21(hidden)    #J
        z_scale = torch.exp(self.fc22(hidden))    #J
        return z_loc, z_scale
#A The encoder is an instance of a Pytorch module.
#B The input image is 28X28 = 784 pixels.
#C The digit dimension is 10.
#D In the encoder, we’ll only use the softplus transform (activation function).
#E The linear transform fc1 combines with the softplus to map the 784-dimensional pixel vector, the 10-dimensional digit label vector, and the 1-dimensional is_handwritten indicator to the hidden layer.
#F The linear transforms fc21 and fc22 map the hidden vector to Z’s vector space: fc21 gives the location, and fc22 (exponentiated in forward) gives the scale.
#G Define the reverse computation from an observed X variable value to a latent Z variable value.
#H Combine the image vector, digit label, and is-handwritten label into one input.
#I Map the input to the hidden layer.
#J The VAE framework will sample Z from a Normal distribution that approximates P(Z|img, digit, is_handwritten). The final transforms map the hidden layer to a location and scale parameter for that Normal distribution.
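The encoder only outputs `z_loc` and `z_scale`; Z itself is drawn in the guide. For reparameterizable distributions such as Normal, Pyro draws the sample as z = loc + scale * eps with parameter-free noise eps, so gradients flow back into the encoder. This stdlib sketch (hypothetical `sample_z`, scalar Z) mimics that step, including the `exp` that keeps the scale positive, and checks the empirical mean and standard deviation.

```python
import math
import random

def sample_z(z_loc, log_scale, rng):
    """Draw z ~ Normal(z_loc, exp(log_scale)) elementwise via z = loc + scale * eps."""
    z = []
    for loc, ls in zip(z_loc, log_scale):
        scale = math.exp(ls)        # the encoder applies torch.exp, so scale > 0
        eps = rng.gauss(0.0, 1.0)   # standard normal noise with no learnable parameters
        z.append(loc + scale * eps)
    return z

rng = random.Random(1)
samples = [sample_z([2.0], [math.log(0.5)], rng)[0] for _ in range(20000)]
mean = sum(samples) / len(samples)
std = math.sqrt(sum((s - mean) ** 2 for s in samples) / len(samples))
print(round(mean, 2), round(std, 2))  # close to 2.0 and 0.5
```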

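There is no encoder for the diffusion model. The forward (noising) process has no learned parameters: given x_0, a timestep t, and noise eps, `q_sample` computes x_t deterministically, so the guide only needs to sample t and eps.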

## Listing 5.8: The guide function

`training_guide` contains the encoder. The purpose of `training_guide` is to approximate P(Z|image, digit, is_handwritten) during training.

In [15]:
def training_guide(self, img, digit, is_handwritten, batch_size):    #A
    pyro.module("encoder", self.encoder)    #B
    options = dict(dtype=torch.float32, device=DEVICE_TYPE)
    with pyro.plate("data", batch_size):    #C
        z_loc, z_scale = self.encoder(img, digit, is_handwritten)    #D
        normal_dist = dist.Normal(z_loc, z_scale).to_event(1)    #D
        z = pyro.sample("Z", normal_dist)    #E
#A training_guide is a method of the VAE which will use the encoder.
#B Register the encoder so Pyro is aware of its weight parameters.
#C This is the same plate context manager for iterating over the batch data that we see in the training_model.
#D Use the encoder to map an image and its labels to parameters of a Normal distribution.
#E Sample Z from that Normal distribution.

In [None]:
def training_guide(self, img, digit, is_handwritten, T, batch_size):
    batch_size = img.size(0)    
    with pyro.plate("data", batch_size):
        # 1. We "infer" a random timestep for this training piece
        batched_probs = torch.ones(batch_size, T, device=img.device) / T
        t = pyro.sample("t", dist.Categorical(batched_probs))       
        # 2. We "infer" the exogenous noise.
        # This acts as the 'Z' (latent variable) for this specific image.
        
        eps = pyro.sample("eps", dist.Normal(0, 1).expand(img.shape).to_event(1))

## Listing 5.9: The full VAE code

Now we assemble all the parts into the VAE class.

In [16]:
class VAE(nn.Module):
    def __init__(
        self,
        z_dim=50,    #A
        hidden_dim=400,    #B
        use_cuda=USE_CUDA,
    ):
        super().__init__()
        self.use_cuda = use_cuda
        self.z_dim = z_dim
        self.hidden_dim = hidden_dim
        self.setup_networks()

    def setup_networks(self):    #C
        self.encoder = Encoder(self.z_dim, self.hidden_dim)
        self.decoder = Decoder(self.z_dim, self.hidden_dim)
        if self.use_cuda:
            self.cuda()


    # This is where Pyro gets integrated with PyTorch
    model = model    #D
    training_model = training_model    #D
    training_guide = training_guide    #D

#A Setting the latent dimension to have a dimension of 50.
#B Setting the hidden layers to have a dimension of 400.
#C Setup the encoder and decoder.
#D Adding in the methods for model, training_model, and training_guide.

In [None]:
class Diffusion(nn.Module):
    def __init__(self, T=1000, hidden_dim=400, use_cuda=USE_CUDA):
        super().__init__()
        self.use_cuda = use_cuda
        self.T = T
        self.hidden_dim = hidden_dim
        
        # 1. Setup the reverse-process network (Reverse Diffusion)
        self.decoder = Decoder(hidden_dim)
        
        # 2. Setup fixed diffusion schedule (The Forward Process)
        self.setup_schedule()
        
        if self.use_cuda:
            self.cuda()

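    def setup_schedule(self):
        """Precompute the beta, alpha, and alpha-bar schedules for q_sample and p_sample."""
        betas = torch.linspace(1e-4, 0.02, self.T)
        alphas = 1.0 - betas
        # Register all three as buffers so they move to the GPU with the model;
        # p_sample indexes betas and alphas with a tensor on the model's device.
        self.register_buffer('betas', betas)
        self.register_buffer('alphas', alphas)
        self.register_buffer('alpha_bars', torch.cumprod(alphas, dim=0))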

    def q_sample(self, x0, t, eps):
        """The forward process: x0 -> xt"""
        sqrt_ab = self.alpha_bars[t].sqrt().unsqueeze(1)
        sqrt_1_ab = (1 - self.alpha_bars[t]).sqrt().unsqueeze(1)
        return sqrt_ab * x0 + sqrt_1_ab * eps

    def model(self, x0, data_size=1):
        pyro.module("decoder", self.decoder)
        device = x0.device
        options = dict(dtype=torch.float32, device=device)
        
        # 1. Sample diffusion time (Batched)
        batched_probs = torch.ones(data_size, self.T, device=device) / self.T
        t = pyro.sample("t", dist.Categorical(batched_probs))

        # 2. Sample exogenous noise (The latent 'style' equivalent)
        eps = pyro.sample("eps", dist.Normal(0, 1).expand(x0.shape).to_event(1))
        
        # 3. Sample Causal Parents (Conditioned during training)
        p_digit = torch.ones(data_size, 10, **options) / 10
        digit = pyro.sample("digit", dist.OneHotCategorical(p_digit))

        p_is_hw = torch.ones(data_size, 1, **options) / 2
        is_handwritten = pyro.sample("is_handwritten", dist.Bernoulli(p_is_hw).to_event(1))

        # 4. Forward diffusion (Causal path)
        x_t = self.q_sample(x0, t, eps)

        # 5. Predict noise (Mechanism)
        eps_hat = self.decoder(x_t, digit, is_handwritten, t)

        # 6. Observation (Log-likelihood matching VAE's Bernoulli line)
        pyro.sample("obs_eps", dist.Normal(eps_hat, 1.0).to_event(1), obs=eps)

        return x_t, eps, digit, is_handwritten

    def training_model(self, img, digit, is_handwritten, batch_size):
        conditioned_on_data = pyro.condition(
            self.model,
            data={"digit": digit, "is_handwritten": is_handwritten}
        )
        with pyro.plate("data", batch_size):
            # Pass image as x0 directly
            return conditioned_on_data(x0=img, data_size=batch_size)

    def training_guide(self, img, digit, is_handwritten, batch_size):
        """The guide mirrors the sampling of latents (t and eps)."""
        with pyro.plate("data", batch_size):
            batched_probs = torch.ones(batch_size, self.T, device=img.device) / self.T
            pyro.sample("t", dist.Categorical(batched_probs))
            pyro.sample("eps", dist.Normal(0, 1).expand(img.shape).to_event(1))

## Listing 5.10 Helper function for plotting images

The following utility function helps us visualize progress during training.

In [17]:
import matplotlib.pyplot as plt

def plot_image(img, title=None):    #A
    fig = plt.figure()
    plt.imshow(img.cpu(), cmap='Greys_r', interpolation='nearest')
    if title is not None:
        plt.title(title)
    plt.show()
#A Helper function for plotting an image

## Listing 5.11: Define helper functions for reconstructing and viewing the images

These utility functions reconstruct an image with the model and plot images side by side for comparison.

In [None]:
import matplotlib.pyplot as plt

In [18]:
# import matplotlib.pyplot as plt

# def reconstruct_img(vae, img, digit, is_hw, use_cuda=USE_CUDA):    #A
#     img = img.reshape(-1, 28 * 28)
#     digit = F.one_hot(torch.tensor(digit), 10)
#     is_hw = torch.tensor(is_hw).unsqueeze(0)
#     if use_cuda:
#         img = img.cuda()
#         digit = digit.cuda()
#         is_hw = is_hw.cuda()
#     z_loc, z_scale = vae.encoder(img, digit, is_hw)
#     z = dist.Normal(z_loc, z_scale).sample()
#     img_expectation = vae.decoder(z, digit, is_hw)
#     return img_expectation.squeeze().view(28, 28).detach()

# def compare_images(img1, img2):    #B
#     fig = plt.figure()
#     ax0 = fig.add_subplot(121)
#     plt.imshow(img1.cpu(), cmap='Greys_r', interpolation='nearest')
#     plt.axis('off')
#     plt.title('original')
#     ax1 = fig.add_subplot(122)
#     plt.imshow(img2.cpu(), cmap='Greys_r', interpolation='nearest')
#     plt.axis('off')
#     plt.title('reconstruction')
#     plt.show()
# #A Given an input image, "reconstructs" the image by passing through the encoder then through the decoder.
# #B Plots two images side by side for comparison.

In [1]:
import torch.nn.functional as F

def reconstruct_img(diffusion, img, digit, is_hw, t_val=500, use_cuda=USE_CUDA):
    # 1. Prepare inputs as a batch of size 1.
    # digit and is_hw may be Python scalars or 1-element NumPy arrays (as returned
    # by CombinedDataset), so flatten them before building the tensors.
    img = img.reshape(1, -1)
    digit_tensor = F.one_hot(torch.as_tensor(digit).long().view(1), 10).float()
    is_hw_tensor = torch.as_tensor(is_hw, dtype=torch.float32).view(1, 1)
    # For reconstruction, we pick a fixed t (e.g., midway through diffusion)
    t_tensor = torch.tensor([t_val])
    
    if use_cuda:
        img = img.cuda()
        digit_tensor = digit_tensor.cuda()
        is_hw_tensor = is_hw_tensor.cuda()
        t_tensor = t_tensor.cuda()

    # 2. Add noise (the forward process)
    # We draw random noise epsilon to corrupt the image
    eps = torch.randn_like(img)
    x_t = diffusion.q_sample(img, t_tensor, eps)

    # 3. Predict the noise (the reverse process)
    # The decoder tries to "see through" the noise
    eps_hat = diffusion.decoder(x_t, digit_tensor, is_hw_tensor, t_tensor)

    # 4. Estimate x0 by inverting the forward process with the predicted noise
    # (DDPM): x0_hat = (x_t - sqrt(1 - alpha_bar_t) * eps_hat) / sqrt(alpha_bar_t)
    sqrt_alpha_bar = diffusion.alpha_bars[t_tensor].sqrt()
    sqrt_one_minus_alpha_bar = (1 - diffusion.alpha_bars[t_tensor]).sqrt()
    
    img_reconstructed = (x_t - sqrt_one_minus_alpha_bar * eps_hat) / sqrt_alpha_bar

    # Return the corrupted input too, so compare_images can show all three panels
    noisy = x_t.squeeze().view(28, 28).detach()
    return noisy, img_reconstructed.squeeze().view(28, 28).detach()

def compare_images(original, noisy, reconstruction, t_val=500):
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    
    images = [original, noisy, reconstruction]
    titles = ['Original', f'Corrupted (t={t_val})', 'Reconstructed']
    
    for ax, img, title in zip(axes, images, titles):
        ax.imshow(img.cpu(), cmap='Greys_r', interpolation='nearest')
        ax.set_title(title)
        ax.axis('off')
        
    plt.show()
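The reconstruction step inverts the forward equation using the predicted noise. This scalar sketch (hypothetical `forward` and `predict_x0`; `alpha_bar = 0.08` is an illustrative mid-schedule value, not taken from the code) shows that a perfect noise estimate recovers x0 exactly, while a noise error of delta shifts the estimate by delta * sqrt((1 - alpha_bar) / alpha_bar). That factor is large at late timesteps, which is one reason reconstructions from large t look blurrier.

```python
import math

def forward(x0, eps, alpha_bar):
    """Forward process: x_t = sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * eps."""
    return math.sqrt(alpha_bar) * x0 + math.sqrt(1.0 - alpha_bar) * eps

def predict_x0(x_t, eps_hat, alpha_bar):
    """Invert the forward process using a (predicted) noise value."""
    return (x_t - math.sqrt(1.0 - alpha_bar) * eps_hat) / math.sqrt(alpha_bar)

alpha_bar = 0.08  # illustrative value, assumed for this sketch
x0, eps = 0.7, -1.3
x_t = forward(x0, eps, alpha_bar)

print(predict_x0(x_t, eps, alpha_bar))        # exact noise estimate recovers x0
print(predict_x0(x_t, eps + 0.1, alpha_bar))  # a small noise error is amplified
```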


## Listing 5.12: Data processing helper functions for training

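Next, we'll create some helper functions for handling the data. We'll use `get_random_example` to grab random images from the dataset, and `reshape_data` to convert an image and its labels into input for the encoder. `generate_coded_data` simulates an image, along with its one-hot digit and binary is_handwritten labels, from the model, and `generate_data` wraps it, reshaping the outputs for plotting. In the diffusion version, `p_sample` performs a single reverse-diffusion step, and `generate_coded_data` applies it from t = T-1 down to t = 0.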

In [19]:
# import torch.nn.functional as F

# def get_random_example(loader):    #A
#     random_idx = np.random.randint(0, len(loader.dataset))    #A
#     img, digit, is_handwritten = loader.dataset[random_idx]    #A
#     return img.squeeze(), digit, is_handwritten    #A

# def reshape_data(img, digit, is_handwritten):    #B
#     digit = F.one_hot(digit, 10).squeeze()    #B
#     img = img.reshape(-1, 28*28)    #B
#     return img, digit, is_handwritten    #B

# def generate_coded_data(vae, use_cuda=USE_CUDA):    #C
#     z_loc = torch.zeros(1, vae.z_dim)    #C
#     z_scale = torch.ones(1, vae.z_dim)    #C
#     z = dist.Normal(z_loc, z_scale).to_event(1).sample()    #C
#     p_digit = torch.ones(1, 10)/10    #C
#     digit = dist.OneHotCategorical(p_digit).sample()    #C
#     p_is_handwritten = torch.ones(1, 1)/2    #C
#     is_handwritten = dist.Bernoulli(p_is_handwritten).sample()    #C
#     if use_cuda:    #C
#         z = z.cuda()
#         digit = digit.cuda()
#         is_handwritten = is_handwritten.cuda()    #C
#     img = vae.decoder(z, digit, is_handwritten)    #C
#     return img, digit, is_handwritten    #C

# def generate_data(vae, use_cuda=USE_CUDA):    #D
#     img, digit, is_handwritten = generate_coded_data(vae, use_cuda)    #D
#     img = img.squeeze().view(28, 28).detach()    #D
#     digit = torch.argmax(digit, 1)    #D
#     is_handwritten = torch.argmax(is_handwritten, 1)    #D
#     return img, digit, is_handwritten    #D
# #A Chose a random example from the dataset.
# #B Reshape the data.
# #C Generate data that is encoded.
# #D Generate (unencoded) data.

In [2]:
import torch.nn.functional as F

def get_random_example(loader):    #A
    random_idx = np.random.randint(0, len(loader.dataset))    #A
    img, digit, is_handwritten = loader.dataset[random_idx]    #A
    return img.squeeze(), digit, is_handwritten    #A

def reshape_data(img, digit, is_handwritten):    #B
    digit = F.one_hot(digit, 10).squeeze()    #B
    img = img.reshape(-1, 28*28)    #B
    return img, digit, is_handwritten    #B

@torch.no_grad()
def p_sample(diffusion, x_t, t, digit, is_hw):
    """
    Reverse diffusion step: samples x_{t-1} given x_t.
    """
    # 1. Predict the noise using the decoder
    eps_hat = diffusion.decoder(x_t, digit, is_hw, t)

    # 2. Calculate coefficients for the reverse step
    # Based on the DDPM math: x_{t-1} = 1/sqrt(alpha) * (x_t - beta/sqrt(1-ab) * eps_hat)
    beta_t = diffusion.betas[t]
    sqrt_one_minus_alpha_bar_t = (1 - diffusion.alpha_bars[t]).sqrt()
    sqrt_alpha_t = diffusion.alphas[t].sqrt()

    # The predicted mean of x_{t-1}
    mean = (1 / sqrt_alpha_t) * (x_t - (beta_t / sqrt_one_minus_alpha_bar_t) * eps_hat)

    if t == 0:
        return mean
    else:
        # Add fresh Gaussian noise (DDPM ancestral sampling) so generation stays stochastic
        noise = torch.randn_like(x_t)
        sigma_t = beta_t.sqrt() # Standard DDPM variance choice
        return mean + sigma_t * noise
    
def generate_coded_data(diffusion, use_cuda=USE_CUDA):
    device = torch.device("cuda" if use_cuda else "cpu")
    
    # 1. Start with Pure Noise (The Diffusion 'Latent')
    img_shape = (1, 28 * 28)
    x = torch.randn(img_shape, device=device)

    # 2. Sample Causal Parents (Digit and Style)
    p_digit = torch.ones(1, 10, device=device) / 10
    digit = dist.OneHotCategorical(p_digit).sample()
    
    p_is_hw = torch.ones(1, 1, device=device) / 2
    is_handwritten = dist.Bernoulli(p_is_hw).sample()

    # 3. The Reverse Loop (The 'Generation' work)
    # We step from t=999 all the way down to t=0
    for t_idx in reversed(range(diffusion.T)):
        t_tensor = torch.tensor([t_idx], device=device)
        x = p_sample(diffusion, x, t_tensor, digit, is_handwritten)

    return x, digit, is_handwritten

def generate_data(diffusion, use_cuda=USE_CUDA):
    # This remains the wrapper that cleans up shapes for plotting
    img, digit, is_handwritten = generate_coded_data(diffusion, use_cuda)
    
    img = img.squeeze().view(28, 28).detach()
    digit = torch.argmax(digit, 1)
    # is_handwritten for Bernoulli is usually just the value
    is_handwritten = is_handwritten.squeeze().round().int() 
    
    return img, digit, is_handwritten
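The mean computed in `p_sample` is the mean of the forward-process posterior q(x_{t-1} | x_t, x_0) from DDPM, with x_0 replaced by the estimate implied by the predicted noise. This stdlib sketch rebuilds the same linear schedule and checks numerically that the two forms agree; the function names are hypothetical. The code's choice sigma_t = sqrt(beta_t) is one of the two variance choices discussed in the DDPM paper.

```python
import math

T = 1000
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]
alphas = [1.0 - b for b in betas]
alpha_bars, running = [], 1.0
for a in alphas:
    running *= a
    alpha_bars.append(running)

def p_sample_mean(x_t, eps_hat, t):
    """Mean used in p_sample: (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_hat) / sqrt(alpha_t)."""
    return (x_t - betas[t] / math.sqrt(1.0 - alpha_bars[t]) * eps_hat) / math.sqrt(alphas[t])

def predict_x0(x_t, eps_hat, t):
    """x0 estimate implied by the predicted noise."""
    return (x_t - math.sqrt(1.0 - alpha_bars[t]) * eps_hat) / math.sqrt(alpha_bars[t])

def posterior_mean(x_t, x0, t):
    """Mean of q(x_{t-1} | x_t, x0), written in terms of x0 and x_t (valid for t >= 1)."""
    ab_t, ab_prev = alpha_bars[t], alpha_bars[t - 1]
    coef_x0 = math.sqrt(ab_prev) * betas[t] / (1.0 - ab_t)
    coef_xt = math.sqrt(alphas[t]) * (1.0 - ab_prev) / (1.0 - ab_t)
    return coef_x0 * x0 + coef_xt * x_t

t, x_t, eps_hat = 500, 0.4, -0.9
print(p_sample_mean(x_t, eps_hat, t), posterior_mean(x_t, predict_x0(x_t, eps_hat, t), t))
```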


## Listing 5.13: Set up the training procedure

Next we set up training. The training objective `Trace_ELBO` simultaneously trains the parameters of the encoder and the decoder. It focuses on minimizing reconstruction error (how much information is lost when an image is encoded and then decoded again) and the [KL-divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence) between the distribution modeled by the guide (the variational distribution) and the posterior P(Z|image, is_handwritten, digit).
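In the standard decomposition, the ELBO is the expected reconstruction log-likelihood minus KL(q(Z | image, labels) || p(Z)). Since log P(image) does not depend on the guide, maximizing the ELBO also minimizes the KL to the posterior described above. With the N(0, I) prior from `model` and the diagonal Normal from the guide, the prior-KL term has a closed form, sketched below with hypothetical helper names. `Trace_ELBO` estimates this term from samples rather than using the closed form; the sketch only shows the quantity being estimated.

```python
import math

def kl_normal_std_normal(loc, scale):
    """KL( N(loc, scale^2) || N(0, 1) ) for one dimension, in nats."""
    return 0.5 * (scale ** 2 + loc ** 2 - 1.0) - math.log(scale)

def kl_diag(locs, scales):
    """Sum over independent dimensions, as for a diagonal Gaussian over Z."""
    return sum(kl_normal_std_normal(m, s) for m, s in zip(locs, scales))

print(kl_normal_std_normal(0.0, 1.0))  # 0.0: the guide matches the prior exactly
print(kl_diag([1.0, 0.0], [1.0, 0.5]))
```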

In [20]:
# from pyro.infer import SVI, Trace_ELBO
# from pyro.optim import Adam

# pyro.clear_param_store()    #A
# vae = VAE()    #B
# train_loader, test_loader = setup_dataloaders(batch_size=256)    #C
# svi_adam = Adam({"lr": 1.0e-3})    #D
# model = vae.training_model    #E
# guide = vae.training_guide    #E
# svi = SVI(model, guide, svi_adam, loss=Trace_ELBO())    #E
# #A Clear any values of the parameters in the guide memory.
# #B Initialize the VAE
# #C Load the data
# #D Initialize the optimizer
# #E Initialize the SVI loss calculator. The loss is the negative evidence lower bound (ELBO).

(50000, 787)


In [None]:
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

# 1. Clear previous parameters
pyro.clear_param_store()

# 2. Initialize the Diffusion Class
T_steps = 1000
diffusion_model = Diffusion(T=T_steps, hidden_dim=400)
train_loader, test_loader = setup_dataloaders(batch_size=256)

# 3. Setup Optimizer
# We keep the VAE's learning rate here; lower rates are also common for diffusion models
optimizer = Adam({"lr": 1.0e-3})

# 4. Define Model and Guide
# We point SVI to the 'training' versions we wrote
model = diffusion_model.training_model
guide = diffusion_model.training_guide

# 5. Initialize SVI
# Trace_ELBO works here because our 'model' has an 'obs=' statement 
# and our 'guide' samples the same latent variables (t, eps).
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
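Why `Trace_ELBO` yields a denoising objective here, as far as I can read the model and guide: the guide samples t and eps from exactly the distributions the model uses as priors, so their log p - log q terms cancel sample by sample, and digit and is_handwritten are conditioned on data with uniform priors, contributing constants. What remains is the `obs_eps` log-likelihood. With unit variance that equals -0.5 times the summed squared error minus a constant, so maximizing the ELBO amounts to minimizing the squared error between the true and predicted noise (up to scale and constants). The stdlib sketch below, with hypothetical helper names, checks that identity.

```python
import math

def gaussian_log_lik(eps, eps_hat, sigma=1.0):
    """Sum of log N(eps_i; eps_hat_i, sigma^2): what the 'obs_eps' site contributes."""
    d = len(eps)
    sse = sum((e - h) ** 2 for e, h in zip(eps, eps_hat))
    return -0.5 * sse / sigma ** 2 - 0.5 * d * math.log(2 * math.pi * sigma ** 2)

def mse(eps, eps_hat):
    """Mean squared error between true and predicted noise."""
    return sum((e - h) ** 2 for e, h in zip(eps, eps_hat)) / len(eps)

eps = [0.5, -1.0, 2.0]
good, bad = [0.4, -0.9, 1.8], [0.0, 0.0, 0.0]
# With sigma fixed, higher log-likelihood corresponds exactly to lower MSE
print(gaussian_log_lik(eps, good) > gaussian_log_lik(eps, bad), mse(eps, good) < mse(eps, bad))
```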

## Listing 5.14: Setting up a test evaluation procedure

When training generative models, it is useful to setup a procedure that uses test data to evaluate how well training is progressing. You can include anything you think is useful to monitor during training. Here, I calculate and print the loss function on the test data, just to make sure test loss is progressively decreasing along with training loss (a flattening of test loss while training loss continued to decrease would indicate overfitting).

But a more direct signal of how well our model is training is to generate and view images. In my test evaluation procedure, I produce two visualizations. First, I inspect how well the model can reconstruct a random image from the test data. I pass the image through the encoder and then through the decoder, creating a “reconstruction” of the image. Then I plot the original and reconstructed images side by side and compare them visually, checking that they are nearly identical.
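
Visual comparison can be complemented with a single number. Below is a minimal sketch of a per-pixel mean squared error between an image and its reconstruction; `reconstruction_mse` is a hypothetical helper and assumes both inputs hold 784 pixel values that can be reshaped to 28x28.

```python
import numpy as np

def reconstruction_mse(original, reconstruction):
    """Mean squared per-pixel difference between two 28x28 images.

    Accepts anything np.asarray understands; convert GPU tensors first with
    .detach().cpu().numpy().
    """
    a = np.asarray(original, dtype=float).reshape(28, 28)
    b = np.asarray(reconstruction, dtype=float).reshape(28, 28)
    return float(np.mean((a - b) ** 2))

img = np.random.default_rng(1).uniform(0.0, 1.0, size=784)
print(reconstruction_mse(img, img))              # identical images: 0.0
print(reconstruction_mse(img, np.zeros(784)))    # larger for a poor reconstruction
```

Tracking this value at each test epoch gives a curve that should fall as reconstructions improve, which is easier to compare across epochs than individual plots.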

Next, I visualize how well it performs as an overall generative model by generating and plotting an image from scratch. I run this procedure once every fixed number of epochs.

In [21]:
# def test_epoch(vae, test_loader):
#     epoch_loss_test = 0    #A
#     for img, digit, is_hw in test_loader:    #A
#         batch_size = img.shape[0]    #A
#         if USE_CUDA:    #A
#             img = img.cuda()    #A
#             digit = digit.cuda()    #A
#             is_hw = is_hw.cuda()    #A
#         img, digit, is_hw = reshape_data(    #A
#             img, digit, is_hw    #A
#         )    #A
#         epoch_loss_test += svi.evaluate_loss(    #A
#             img, digit, is_hw, batch_size    #A
#         )    #A
#     test_size = len(test_loader.dataset)    #A
#     avg_loss = epoch_loss_test/test_size    #A
#     print("Epoch: {} avg. test loss: {}".format(epoch, avg_loss))    #A
#     print("Comparing a random test image to its reconstruction:")    #B
#     random_example = get_random_example(test_loader)    #B
#     img_r, digit_r, is_hw_r = random_example    #B
#     img_recon = reconstruct_img(vae, img_r, digit_r, is_hw_r)    #B
#     compare_images(img_r, img_recon)    #B
#     print("Generate a random image from the model:")    #C
#     img_gen, digit_gen, is_hw_gen = generate_data(vae)    #C
#     plot_image(img_gen, "Generated Image")    #C
#     print("Intended digit: ", int(digit_gen))    #C
#     print("Intended as handwritten: ", bool(is_hw_gen == 1))    #C
# #A Calculate and print test loss.
# #B Compare a random test image to its reconstruction.
# #C Generate a random image from the model.

In [None]:
def test_epoch(diffusion, test_loader, epoch):
    epoch_loss_test = 0
    
    # 1. Calculate Test Loss (Quantitative Check)
    for img, digit, is_hw in test_loader:
        batch_size = img.shape[0]
        if USE_CUDA:
            img, digit, is_hw = img.cuda(), digit.cuda(), is_hw.cuda()
            
        img, digit, is_hw = reshape_data(img, digit, is_hw)
        
        # evaluate_loss does everything step() does but WITHOUT updating weights
        epoch_loss_test += svi.evaluate_loss(img, digit, is_hw, batch_size)
        
    test_size = len(test_loader.dataset)
    avg_loss = epoch_loss_test / test_size
    print(f"Epoch: {epoch} avg. test loss: {avg_loss:.4f}")

    # 2. Visualize Reconstruction (Qualitative Check: Denoising)
    print("Comparing a random test image to its (one-step) reconstruction:")
    img_r, digit_r, is_hw_r = get_random_example(test_loader)
    
    # We use a mid-range t (e.g., 400) to see if it can recover from significant noise
    t_test = 400 
    img_recon = reconstruct_img(diffusion, img_r, digit_r, is_hw_r, t_val=t_test)
    
    # Show the original (reshaped to 28x28) next to its reconstruction; adjust compare_images if it expects different shapes
    compare_images(img_r.view(28, 28), img_recon)

    # 3. Generate New Image (Qualitative Check: Full Reverse Process)
    print("Generate a brand new image by sampling from pure noise:")
    # This calls your generate_data function with the 1000-step loop
    img_gen, digit_gen, is_hw_gen = generate_data(diffusion)
    
    plot_image(img_gen, f"Generated: Digit {int(digit_gen)}")
    print(f"Intended digit: {int(digit_gen)}")
    print(f"Intended as handwritten: {bool(is_hw_gen == 1)}")
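
For the diffusion model, the “reconstruction” is a one-step estimate of the clean image from a noised one. `reconstruct_img` is not shown here, so this is a sketch under the assumption that it follows the standard DDPM approach: noise the image to step `t`, let the network predict the noise, and solve the forward-process equation for x_0. The linear schedule and the helper name `predict_x0_from_eps` are assumptions.

```python
import numpy as np

# Assumption: the network predicts the added noise eps (DDPM-style) under a
# linear beta schedule; reconstruct_img's actual internals are not shown.
betas = np.linspace(1e-4, 0.02, 1000)
alpha_bar = np.cumprod(1.0 - betas)

def predict_x0_from_eps(x_t, eps_hat, alpha_bar_t):
    """Invert x_t = sqrt(ab) * x0 + sqrt(1 - ab) * eps for x0 using predicted noise."""
    return (x_t - np.sqrt(1.0 - alpha_bar_t) * eps_hat) / np.sqrt(alpha_bar_t)

rng = np.random.default_rng(0)
x0 = rng.uniform(0.0, 1.0, size=784)
eps = rng.standard_normal(784)
t = 400                                   # same mid-range step as t_test
x_t = np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * eps

x0_perfect = predict_x0_from_eps(x_t, eps, alpha_bar[t])     # exact noise
eps_hat = eps + 0.1 * rng.standard_normal(784)               # imperfect prediction
x0_noisy = predict_x0_from_eps(x_t, eps_hat, alpha_bar[t])
x0_display = np.clip(x0_noisy, 0.0, 1.0).reshape(28, 28)     # ready for plotting
```

With the exact noise, the estimate recovers x_0. With imperfect noise predictions, the error is scaled by sqrt((1 - alpha_bar_t) / alpha_bar_t), which is about 2 at t = 400 under this assumed schedule, so some blur in mid-range reconstructions is expected early in training.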

## Listing 5.15: Running training and plotting progress

Finally, we run training.

In [None]:
# NUM_EPOCHS = 2500
# TEST_FREQUENCY = 10

# train_loss = []
# train_size = len(train_loader.dataset)

# for epoch in range(0, NUM_EPOCHS+1):    #A
#     loss = 0
#     for img, digit, is_handwritten in train_loader:
#         batch_size = img.shape[0]
#         if USE_CUDA:
#             img = img.cuda()
#             digit = digit.cuda()
#             is_handwritten = is_handwritten.cuda()
#         img, digit, is_handwritten = reshape_data(
#             img, digit, is_handwritten
#         )
#         loss += svi.step(    #B
#             img, digit, is_handwritten, batch_size    #B
#         )    #B
#     avg_loss = loss / train_size
#     print("Epoch: {} avg. training loss: {}".format(epoch, avg_loss))
#     train_loss.append(avg_loss)
#     if epoch % TEST_FREQUENCY == 0:    #C
#         test_epoch(vae, test_loader)    #C
# #A Run the training procedure for a certain number of epochs.
# #B Run a training step on one batch in one epoch.
# #C The test data evaluation procedure runs every 10 epochs.

In [None]:
NUM_EPOCHS = 2500
TEST_FREQUENCY = 10  # Note: Generation is slow, so you might increase this later

train_loss = []
train_size = len(train_loader.dataset)

for epoch in range(0, NUM_EPOCHS + 1):
    loss = 0
    for img, digit, is_handwritten in train_loader:
        batch_size = img.shape[0]
        
        if USE_CUDA:
            img = img.cuda()
            digit = digit.cuda()
            is_handwritten = is_handwritten.cuda()
            
        # Re-use your existing reshape_data function
        img, digit, is_handwritten = reshape_data(img, digit, is_handwritten)
        
        # SVI.step logic:
        # This calls: training_model(img, digit, is_handwritten, batch_size)
        # And: training_guide(img, digit, is_handwritten, batch_size)
        loss += svi.step(img, digit, is_handwritten, batch_size)
        
    avg_loss = loss / train_size
    train_loss.append(avg_loss)
    
    print(f"Epoch: {epoch} | Avg Training Loss (negative ELBO): {avg_loss:.4f}")

    # Run the qualitative and quantitative test procedure
    if epoch % TEST_FREQUENCY == 0:
        # Note: We pass 'diffusion_model' (your class instance) instead of 'vae'
        test_epoch(diffusion_model, test_loader, epoch)
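
The loop stores one average loss per epoch in `train_loss` but never plots it. Per-epoch ELBO estimates are noisy, so a trailing moving average makes the trend easier to read. This is a minimal sketch; `moving_average` and the window size are illustrative, and the result can be passed to a plotting function such as matplotlib's `plt.plot`.

```python
def moving_average(values, window=10):
    """Trailing moving average; the first window-1 points average fewer values."""
    smoothed = []
    running = 0.0
    for i, value in enumerate(values):
        running += value
        if i >= window:
            running -= values[i - window]  # drop the value that left the window
        smoothed.append(running / min(i + 1, window))
    return smoothed

print(moving_average([4.0, 2.0, 3.0, 1.0], window=2))
```

For example, `plt.plot(moving_average(train_loss, window=25))` would draw the smoothed training curve after training finishes.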

We can continue to use `generate_data` to generate from the model once we've trained it. Finally, we can save the resulting model.

In [None]:
#torch.save(vae.state_dict(), 'mnist_tmnist_weights.pt')