# Example: Deep structural causal model counterfactuals

In [None]:
from typing import Dict, List, Optional, Tuple, Union, TypeVar

import torch
from torch.utils.data import DataLoader

import pyro
import pyro.distributions as dist
from pyro.poutine import condition, reparam
from pyro.nn import PyroParam, PyroSample, PyroModule
import pyro.distributions.transforms as Transforms
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.infer import config_enumerate
from pyro.distributions import constraints


import causal_pyro
from causal_pyro.query.do_messenger import do
from causal_pyro.counterfactual.handlers import Factual, MultiWorldCounterfactual, TwinWorldCounterfactual
from causal_pyro.reparam.soft_conditioning import TransformInferReparam

import pandas as pd
import skimage
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
import gzip
import struct
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Select the compute device once; tensors created below are placed on `curr_device`.
available = torch.cuda.is_available()
if available:
    # Index (int) of the currently selected CUDA device.
    curr_device = torch.cuda.current_device()
else:
    curr_device = torch.device("cpu")
print(f'Cuda available: {available}')
print(f'Current device: {curr_device}')

if available:
    device_count, device_name = torch.cuda.device_count(), torch.cuda.get_device_name(0)
    print(f'Device count: {device_count}')
    print(f'Device name: {device_name}')

## Background: Normalizing flows and counterfactuals

Much of the causal inference literature has focused on relatively simple
causal models with low dimensional data. In order to perform
counterfactual reasoning in more complex domains with high dimensional
data, Pawlowski et al. [@pawlowski2020deep] introduced *deep structural
causal models* (Deep SCMs): SCMs with neural networks as the functional
mechanisms between variables.

Specifically, the neural networks are
*normalizing flows*. A normalizing flow transforms a base probability
distribution (often a simple distribution, such as a multivariate
Gaussian) through a sequence of invertible transformations into a more
complex distribution (such as a distribution over images). When used
within a Deep SCM, the flow's base distribution is an exogenous noise
variable, and its output is an endogenous variable.

A salient property
of normalizing flows is that computing the likelihood of data can be
done both exactly and efficiently, and hence training a flow to model a
data distribution through maximum likelihood is straightforward. In
addition, the inverse of a normalizing flow can also typically be
efficiently computed, which renders the abduction step of a
counterfactual---inferring the posterior over exogenous variables given
evidence---trivial.

## Example: Morpho-MNIST

We consider a synthetic dataset based on MNIST, where the image of each digit ($X$) depends on the stroke thickness ($T$) and the brightness, or intensity ($I$), of the image, and the intensity in turn depends on the thickness.

We assume that the full causal structure is known (i.e., there are no unobserved confounders).

In [None]:
def load_idx(path: str) -> np.ndarray:
    """Read an array stored in IDX format from disk.

    Parameters
    ----------
    path : str
        Path of the input file. Will uncompress with `gzip` if path ends in '.gz'.

    Returns
    -------
    np.ndarray
        Array with the shape recorded in the IDX header, converted to dtype ``float32``.

    Raises
    ------
    ValueError
        If the two leading magic bytes are not zero, or the header's data-type
        code is not one defined by the IDX format.

    References
    ----------
    http://yann.lecun.com/exdb/mnist/
    """
    # IDX data-type code (third magic byte) -> big-endian element dtype.
    idx_dtypes = {
        0x08: np.dtype('>u1'),
        0x09: np.dtype('>i1'),
        0x0B: np.dtype('>i2'),
        0x0C: np.dtype('>i4'),
        0x0D: np.dtype('>f4'),
        0x0E: np.dtype('>f8'),
    }
    open_fcn = gzip.open if path.endswith('.gz') else open
    with open_fcn(path, 'rb') as f:
        zero_a, zero_b, idx_dtype, ndim = struct.unpack('BBBB', f.read(4))
        if zero_a != 0 or zero_b != 0:
            raise ValueError(f"Not an IDX file (bad magic bytes): {path!r}")
        try:
            dtype = idx_dtypes[idx_dtype]
        except KeyError:
            raise ValueError(
                f"Unsupported IDX data type code {idx_dtype:#04x} in {path!r}"
            ) from None
        # One big-endian uint32 per dimension follows the magic number.
        shape = struct.unpack('>' + 'I' * ndim, f.read(4 * ndim))
        count = int(np.prod(shape))
        data = np.frombuffer(f.read(count * dtype.itemsize), dtype=dtype)
    return data.reshape(shape).astype(np.float32)
    
# Morpho-MNIST: per-image morphometrics (CSV) next to the MNIST training images (IDX).
path = os.path.join(os.getcwd(), "../datasets/morphomnist/")
metrics = pd.read_csv(os.path.join(path, "train-morpho.csv"), index_col='index')
raw_images = load_idx(os.path.join(path, "train-images-idx3-ubyte.gz"))

# Scalar per-image covariates, as float32 tensors on the working device.
thickness, intensity = (
    torch.tensor(metrics[column], dtype=torch.float32, device=curr_device)
    for column in ("thickness", "intensity")
)

In [None]:
fig = plt.figure()
rows, columns = 2, 2
# 2x2 preview of the first four raw images.
for position in range(1, rows * columns + 1):
    fig.add_subplot(rows, columns, position)
    plt.imshow(raw_images[position - 1])
    plt.axis('off')

### Downsampling images:

In [None]:
# Halve both spatial dimensions by summing non-overlapping 2x2 pixel blocks
# (np.sum is block_reduce's default reduction; spelled out for clarity).
images = torch.tensor(
    skimage.measure.block_reduce(raw_images, block_size=(1, 2, 2), func=np.sum),
    dtype=torch.float32,
    device=curr_device,
)
# Side length used everywhere below; assumes square images (height == width).
im_size = images.shape[1]
im_size

In [None]:
fig = plt.figure()
rows, columns = 2, 2
# 2x2 preview of the first four downsampled images (moved to CPU for matplotlib).
for position in range(1, rows * columns + 1):
    fig.add_subplot(rows, columns, position)
    plt.imshow(images[position - 1].cpu())
    plt.axis('off')

## Model: deep structural causal model

The following code models morphological transformations of MNIST,
defining a causal generative model over digits that contains endogenous
variables to control the thickness $t$ and intensity $i$ of the stroke:

In [None]:
class IntensityTransform(dist.ConditionalTransformModule):
    """Conditional flow for intensity given a context (thickness).

    ``condition(context)`` returns the composition
    conditional affine autoregressive layer (conditioned on ``context``)
    -> sigmoid -> affine map ``bias + weight * y``.

    The whole flow must be a conditional transform: a plain
    ``ComposeTransformModule`` that merely contains a conditional layer cannot
    itself be conditioned, so that layer would never receive its context.
    """

    def __init__(self, intensity_size: int, thickness_size: int, hidden_dims: List[int], weight: torch.Tensor, bias: torch.Tensor):
        super().__init__()
        self.intensity_size = intensity_size
        self.thickness_size = thickness_size
        self.hidden_dims = hidden_dims
        self.autoregressive = Transforms.ConditionalAffineAutoregressive(pyro.nn.ConditionalAutoRegressiveNN(
            intensity_size,
            thickness_size,
            hidden_dims=list(hidden_dims),
            nonlinearity=torch.nn.Identity(),
        ))
        # Fixed (non-learned) output rescaling, stored as buffers so it follows
        # `.to(device)` together with the network parameters.
        self.register_buffer("bias", torch.as_tensor(bias))
        self.register_buffer("weight", torch.as_tensor(weight))

    def condition(self, context: torch.Tensor):
        return Transforms.ComposeTransform([
            self.autoregressive.condition(context),
            Transforms.SigmoidTransform(),
            Transforms.AffineTransform(loc=self.bias, scale=self.weight),
        ])


class ThicknessTransform(Transforms.ComposeTransformModule):
    """Unconditional thickness flow: learnable spline, fixed affine map, then exp.

    The final exp makes every output strictly positive.
    """

    def __init__(self, thickness_size: int, weight: torch.Tensor, bias: torch.Tensor):
        self.thickness_size = thickness_size
        stages = [Transforms.Spline(thickness_size)]
        stages.append(Transforms.AffineTransform(loc=bias, scale=weight))
        stages.append(Transforms.ExpTransform())
        super().__init__(stages)
    

class PreprocessTransform(Transforms.ComposeTransformModule):
    """Parameter-free chain: x -> x / 2**num_bits -> alpha + (1 - alpha) * x -> logit(x)."""

    def __init__(self, alpha: float, num_bits: int):
        self.alpha = alpha
        self.num_bits = num_bits
        bit_scale = 1. / 2 ** num_bits
        steps = [
            Transforms.AffineTransform(0., bit_scale),
            Transforms.AffineTransform(alpha, 1 - alpha),
            Transforms.SigmoidTransform().inv,  # inverse sigmoid, i.e. logit
        ]
        super().__init__(steps)


class ImgAffineCouplingTransform(Transforms.AffineCoupling):
    """Affine coupling layer over a flattened image vector.

    The first ``input_dim // 2`` coordinates feed a dense hypernetwork (one
    hidden layer of ``hidden_dim * input_dim`` units) that outputs the two
    parameter vectors (each of size ``input_dim - input_dim // 2``) for the
    remaining coordinates.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        *,
        log_scale_min_clip: float = -1.,
        log_scale_max_clip: float = 5.0,
        nonlinearity: Optional[torch.nn.Module] = None,
    ):
        if nonlinearity is None:
            # Build a fresh activation per layer instead of sharing one module
            # instance created once as a default argument.
            nonlinearity = torch.nn.LeakyReLU()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        split_dim = input_dim // 2
        transformed_dim = input_dim - split_dim
        super().__init__(
            split_dim,
            pyro.nn.DenseNN(
                split_dim,
                [hidden_dim * input_dim],
                [transformed_dim, transformed_dim],
                nonlinearity=nonlinearity,
            ),
            log_scale_min_clip=log_scale_min_clip,
            log_scale_max_clip=log_scale_max_clip,
        )


class ConditionalPermute(dist.ConditionalTransformModule):
    """Conditional wrapper around a permutation of the last dimension; the context is ignored.

    The permutation is drawn once at construction and stored as a buffer, so
    every ``condition`` call uses the same ordering. A flow must be a fixed
    function for its density, and for abduction via its inverse, to be
    consistent across calls. The buffer also follows ``.to(device)`` and is
    saved in ``state_dict``.
    """

    def __init__(self, size: int):
        # No `event_dim` kwarg: ConditionalTransformModule passes its kwargs on
        # toward torch.nn.Module.__init__, which does not accept it.
        super().__init__()
        self.size = size
        self.register_buffer("permutation", torch.randperm(size))

    def condition(self, context: torch.Tensor):
        return Transforms.Permute(self.permutation.to(context.device))


class ConditionalComposeTransformModule(dist.ConditionalTransformModule):
    """Compose (possibly conditional) transforms into a single conditional transform.

    Unconditional parts are wrapped in ``ConstantConditionalTransform`` so that
    ``condition(context)`` can be called uniformly on every part; the result is
    the composition of the conditioned parts, in order.
    """

    def __init__(self, transforms: List[Union[dist.TransformModule, dist.ConditionalTransformModule]]):
        # Must run before anything is assigned to `self` (torch.nn.Module bookkeeping).
        # No `event_dim` kwarg: it would be forwarded toward torch.nn.Module.__init__.
        super().__init__()
        self.transforms = [
            t if isinstance(t, dist.ConditionalTransform)
            else dist.conditional.ConstantConditionalTransform(t)
            for t in transforms
        ]
        # A plain Python list does not register submodules, so register every
        # learnable part explicitly; otherwise its parameters would be missing
        # from .parameters(), .to(device) and the optimizer.
        for i, t in enumerate(transforms):
            if isinstance(t, torch.nn.Module):
                self.add_module(str(i), t)

    def condition(self, context: torch.Tensor):
        return Transforms.ComposeTransformModule([t.condition(context) for t in self.transforms])


class ConditionalImgTransform(ConditionalComposeTransformModule):
    """Conditional flow for a flattened image given a (thickness, intensity) context.

    Parts, in the order they are applied from base noise toward the image:
    preprocessing, context-conditioned affine autoregressive layer, batch norm,
    permutation, affine coupling, permutation, affine coupling, batch norm.
    The context is expected to have ``thickness_size + intensity_size`` features.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        thickness_size: int,
        intensity_size: int,
        *,
        alpha: float = 0.001,
        num_bits: int = 2,
        momentum: float = 0.05,
        log_scale_min_clip: float = -1.,
        log_scale_max_clip: float = 5.0,
    ):
        # Build every part as a local first: torch.nn.Module forbids assigning
        # submodules to `self` before the base-class __init__ has run.
        transforms = [
            PreprocessTransform(alpha, num_bits),
            Transforms.ConditionalAffineAutoregressive(
                pyro.nn.ConditionalAutoRegressiveNN(
                    input_dim,
                    thickness_size + intensity_size,
                    [hidden_dim * input_dim],
                    nonlinearity=torch.nn.Identity(),
                ),
                log_scale_min_clip=log_scale_min_clip,
                log_scale_max_clip=log_scale_max_clip,
            ),
            Transforms.BatchNorm(input_dim, momentum=momentum),
            ConditionalPermute(input_dim),
            ImgAffineCouplingTransform(
                input_dim,
                hidden_dim,
                log_scale_min_clip=log_scale_min_clip,
                log_scale_max_clip=log_scale_max_clip,
            ),
            ConditionalPermute(input_dim),
            ImgAffineCouplingTransform(
                input_dim,
                hidden_dim,
                log_scale_min_clip=log_scale_min_clip,
                log_scale_max_clip=log_scale_max_clip,
            ),
            Transforms.BatchNorm(input_dim, momentum=momentum),
        ]
        super().__init__(transforms)
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim


def StandardNormal(*event_shape: int, **kwargs) -> Union[dist.Independent, dist.Normal]:
    """Standard normal whose dimensions in ``event_shape`` are all event dimensions.

    ``kwargs`` (e.g. ``device=...``) are forwarded to the scalar ``torch.zeros``
    / ``torch.ones`` calls that create loc and scale.
    """
    loc = torch.zeros((), **kwargs)
    scale = torch.ones((), **kwargs)
    base = dist.Normal(loc, scale)
    return base.expand(event_shape).to_event(len(event_shape))

In [None]:
class new_DeepSCM(PyroModule):
    """Deep SCM over thickness T, intensity I and a flattened image X.

    Structure: T = f_T(U_T); I = f_I(U_I; T); X = f_X(U_X; T, I), with every
    U a standard normal and every f a (conditional) normalizing flow. Sizes and
    the fixed affine constants of the scalar flows are passed explicitly.
    """

    def __init__(
        self,
        im_size: int,
        hidden_dim: int,
        thickness_size: int,
        intensity_size: int,
        alpha: float,
        num_bits: int,
        thickness_flow_bias: torch.Tensor,
        thickness_flow_weight: torch.Tensor,
        intensity_flow_bias: torch.Tensor,
        intensity_flow_weight: torch.Tensor,
    ):
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # tensor sizes
        self.thickness_size = thickness_size
        self.intensity_size = intensity_size
        self.im_size = im_size
        self.hidden_dim = hidden_dim
        self.input_dim = self.im_size * self.im_size

        # Thickness parameters
        self.thickness_transform = ThicknessTransform(
            thickness_size,
            thickness_flow_weight,
            thickness_flow_bias,
        )

        # Intensity parameters
        self.intensity_transform = IntensityTransform(
            intensity_size,
            thickness_size,
            [hidden_dim],
            intensity_flow_weight,
            intensity_flow_bias,
        )

        # Image parameters. `alpha` and `num_bits` are keyword-only in
        # ConditionalImgTransform, so they must be passed by name.
        self.img_transform = ConditionalImgTransform(
            self.input_dim,
            hidden_dim,
            thickness_size,
            intensity_size,
            alpha=alpha,
            num_bits=num_bits,
        )

    @staticmethod
    def cond_dist(
        transform: Union[Transforms.Transform, Transforms.ConditionalTransform],
        U_dist: dist.TorchDistributionMixin,
        *contexts: torch.Tensor
    ) -> dist.Distribution:
        """Push ``U_dist`` through ``transform``, conditioning on ``contexts`` when given.

        Contexts are broadcast against each other over all but their last
        dimension and concatenated along the last one; the base distribution is
        expanded to the resulting batch shape.
        """
        if not contexts:
            assert isinstance(transform, Transforms.Transform)
            return dist.TransformedDistribution(U_dist, transform)
        batch_shape = torch.broadcast_shapes(*(c.shape[:-1] for c in contexts))
        context = torch.cat([c.expand(batch_shape + (-1,)) for c in contexts], dim=-1)
        U_dist = U_dist.expand(torch.broadcast_shapes(batch_shape, U_dist.batch_shape))
        # ConditionalTransformedDistribution takes a *list* of transforms.
        return dist.ConditionalTransformedDistribution(U_dist, [transform]).condition(context=context)

    def forward(self):
        # Thickness:
        UT_dist = StandardNormal(self.thickness_size, device=self.device)
        T = pyro.sample("T", self.cond_dist(self.thickness_transform, UT_dist))

        # Intensity (depends on thickness):
        UI_dist = StandardNormal(self.intensity_size, device=self.device)
        I = pyro.sample("I", self.cond_dist(self.intensity_transform, UI_dist, T))

        # Image (depends on thickness and intensity); the likelihood is scaled
        # down by the number of pixels.
        UX_dist = StandardNormal(self.input_dim, device=self.device)
        with pyro.poutine.scale(scale=1 / self.input_dim):
            X = pyro.sample("X", self.cond_dist(self.img_transform, UX_dist, T, I))

        return X

In [None]:
# Data-derived constants for the fixed affine layers of the scalar flows:
# intensity uses [min, max - min]; thickness uses mean/std of log-thickness.
intensity_range = intensity.max() - intensity.min()
params = {
    "intensity_flow_bias": intensity.min(),
    "intensity_flow_weight": intensity_range,
    "thickness_flow_bias": thickness.log().mean(),
    "thickness_flow_weight": thickness.log().std(),
}

In [None]:
class DeepSCM(PyroModule):
    """Deep structural causal model for downsampled Morpho-MNIST.

    Structure: T (thickness) -> I (intensity); (T, I) -> X (flattened image).
    Each variable is a normalizing flow applied to standard-normal exogenous noise.

    Reads two notebook globals: ``im_size`` (image side length, read once at
    construction) and ``params`` (affine constants of the scalar flows, read
    on every call).
    """

    def __init__(self):
        super().__init__()
        self.im_size = im_size
        self.input_dim = input_dim = im_size * im_size

        # Thickness parameters
        thickness_param = Transforms.Spline(1)
        # NOTE(review): overrides the spline's declared domain; kept unchanged
        # from the original notebook — confirm it is still needed.
        thickness_param.domain = constraints.positive
        self.thickness_param = thickness_param

        # Intensity parameters: affine autoregressive flow conditioned on T.
        intensity_net = pyro.nn.ConditionalAutoRegressiveNN(1, 1, hidden_dims=[10], nonlinearity=torch.nn.Identity())
        intensity_param = Transforms.ConditionalAffineAutoregressive(intensity_net)
        intensity_param.codomain = constraints.positive  # NOTE(review): as above
        self.intensity_param = intensity_param

        # Image parameters. The conditional layer's context is cat(T, I): 2 features.
        nn_f_X = pyro.nn.ConditionalAutoRegressiveNN(input_dim, 2, [10 * input_dim], nonlinearity=torch.nn.Identity())
        f_X = Transforms.ConditionalAffineAutoregressive(nn_f_X, log_scale_min_clip=-1., log_scale_max_clip=5.)
        f_X.domain = constraints.positive  # NOTE(review): as above
        self.f_X = f_X
        self.norm = Transforms.BatchNorm(input_dim, momentum=0.05)
        split_dim = input_dim // 2
        param_dims = [input_dim - split_dim, input_dim - split_dim]
        nn_affine_coupling = pyro.nn.DenseNN(split_dim, [10 * input_dim], param_dims, nonlinearity=torch.nn.LeakyReLU())
        self.img_affine_coupling = Transforms.AffineCoupling(
            split_dim, nn_affine_coupling, log_scale_min_clip=-1., log_scale_max_clip=5.0
        )
        nn_affine_coupling_2 = pyro.nn.DenseNN(split_dim, [10 * input_dim], param_dims, nonlinearity=torch.nn.LeakyReLU())
        self.img_auto = Transforms.AffineCoupling(
            split_dim, nn_affine_coupling_2, log_scale_min_clip=-1., log_scale_max_clip=5.0
        )
        self.norm_2 = Transforms.BatchNorm(input_dim, momentum=0.05)

        # Permutations between coupling layers. Drawn ONCE so the flow is the
        # same function on every call: training steps, abduction of the
        # exogenous noise and prediction must all see one density. Buffers
        # follow `.to(device)` and are saved in `state_dict`.
        self.register_buffer("perm1", torch.randperm(input_dim))
        self.register_buffer("perm2", torch.randperm(input_dim))

        # Parameter-free preprocessing, built once instead of on every call.
        alpha = 0.001
        num_bits = 2
        self.preprocess_transform = Transforms.ComposeTransform([
            Transforms.AffineTransform(0., (1. / 2 ** num_bits)),
            Transforms.AffineTransform(alpha, (1 - alpha)),
            Transforms.SigmoidTransform().inv,
        ])

    @property
    def device(self) -> torch.device:
        """Device that currently holds the model's buffers (tracks `.to(...)`)."""
        return self.perm1.device

    def forward(self):
        device = self.device

        # Thickness:
        UT = dist.Normal(torch.tensor(0., device=device), torch.tensor(1., device=device)).expand([1]).to_event(1)
        thickness_flow_lognorm = Transforms.AffineTransform(
            loc=params["thickness_flow_bias"], scale=params["thickness_flow_weight"]
        )
        t_transforms = [
            self.thickness_param,
            thickness_flow_lognorm,
            Transforms.ExpTransform(),
        ]
        T = pyro.sample("T", dist.TransformedDistribution(UT, t_transforms))

        # Intensity (conditioned on T):
        UI = dist.Normal(torch.tensor(0., device=device), torch.tensor(1., device=device)).expand([1]).to_event(1)
        intensity_flow_norm = Transforms.AffineTransform(
            loc=params["intensity_flow_bias"], scale=params["intensity_flow_weight"]
        )
        intensity_transforms = [
            self.intensity_param,
            Transforms.SigmoidTransform(),
            intensity_flow_norm,
        ]
        I = pyro.sample("I", dist.ConditionalTransformedDistribution(UI, intensity_transforms).condition(context=T))

        # Image (conditioned on T and I):
        UX = dist.Normal(torch.tensor(0., device=device), torch.tensor(1., device=device)).expand([self.input_dim]).to_event(1)

        # Broadcast T and I to a common batch shape before concatenating them.
        batch_shape = torch.broadcast_shapes(T.shape[:-1], I.shape[:-1])
        T = T.expand(batch_shape + T.shape[-1:])
        I = I.expand(batch_shape + I.shape[-1:])
        assert T.shape == I.shape

        f_X = self.f_X.condition(context=torch.cat((T, I), dim=-1))

        h_X = dist.TransformedDistribution(UX, [
            self.preprocess_transform,
            f_X,
            self.norm,
            Transforms.Permute(self.perm1),
            self.img_affine_coupling,
            Transforms.Permute(self.perm2),
            self.img_auto,
            self.norm_2,
        ])
        # Scale the image log-likelihood down by the number of pixels.
        with pyro.poutine.scale(scale=1 / self.input_dim):
            X = pyro.sample("X", h_X)
        return X

In [None]:
scm = DeepSCM().to(device=curr_device)
# Names of all learnable tensors registered on the model.
print([name for name, _ in scm.named_parameters()])
pyro.render_model(scm)

In [None]:
# Draw one image from the model and display it. imshow rescales the colour map
# to the sample's own min/max; the previous row-wise L2 `F.normalize` rescaled
# each image row independently and distorted the picture.
sample = scm().cpu().detach().reshape((im_size, im_size))
plt.imshow(sample)

In [None]:
# Intervene on intensity, fixing I to a single random value.
intensity_value = torch.randn(1, device=curr_device)
intervened_scm = do(scm, {"I": intensity_value})
pyro.render_model(intervened_scm)

In [None]:
def conditioned_scm(model):
    """Wrap ``model`` so that T, I and X are observed for a batch of data points.

    The returned callable takes ``(t_obs, i_obs, x_obs)``; the batch size is
    ``x_obs.shape[0]`` and is declared with a "data" plate on dim -1.
    """
    def query_model(t_obs, i_obs, x_obs):
        observed = {"X": x_obs, "T": t_obs, "I": i_obs}
        with pyro.condition(data=observed), pyro.plate("data", size=x_obs.shape[0], dim=-1):
            return model()

    return query_model

# Fresh model instance, wrapped so that T, I and X are all observed.
scm = DeepSCM().to(device=curr_device)
conditioned_model = conditioned_scm(scm)
flat_dim = im_size * im_size
# Smoke test on the first three data points, then render on the first two.
imgs = conditioned_model(thickness[:3, None], intensity[:3, None], images[:3].reshape(-1, flat_dim))
pyro.render_model(
    conditioned_model,
    model_args=(thickness[:2][..., None], intensity[:2][..., None], images[:2].reshape(-1, flat_dim)),
)

In [None]:
initial_lr = 0.0001
num_iterations = 2
adam_params = dict(lr=initial_lr, betas=(0.95, 0.999))
optimizer = pyro.optim.Adam(adam_params)


def empty_guide(*args):
    """Guide that declares no sample sites."""
    return None


batch_size = 256

In [None]:
# Divide the log-likelihood by the minibatch size.
scaled_model = pyro.poutine.scale(conditioned_model, scale=1 / batch_size)
svi = SVI(scaled_model, empty_guide, optimizer, loss=Trace_ELBO())
dataset = [(thickness[k], intensity[k], images[k]) for k in range(len(images))]
# Minibatches drawn per epoch; a full pass would need ceil(len(dataset) / batch_size).
n = 2

In [None]:
pyro.clear_param_store()
loss = []
for epoch in range(num_iterations):
    # Fresh shuffled loader each epoch; only the first `n` minibatches are used.
    batches = iter(DataLoader(dataset, batch_size=batch_size, shuffle=True))
    for _ in range(n):
        t_batch, i_batch, x_batch = next(batches)
        step_loss = svi.step(t_batch[..., None], i_batch[..., None], x_batch.reshape(-1, im_size * im_size))
        loss.append(step_loss)
    if epoch % 100 == 0:
        # Mean loss over this epoch's n steps.
        print(sum(loss[-n:]) / n)

In [None]:
plt.plot(range(len(loss)), loss)  # loss per SVI step

In [None]:
# Sample X four times with T and I fixed to the first data point's values.
observed_covariates = {"T": thickness[0][..., None], "I": intensity[0][..., None]}
predictive = pyro.infer.Predictive(
    condition(scm, data=observed_covariates), guide=empty_guide, num_samples=4
)
img = predictive()["X"]

In [None]:
fig = plt.figure()
rows, columns = 2, 2
# Show the four predictive samples. imshow autoscales each panel; the previous
# row-wise L2 `F.normalize` rescaled every image row separately, distorting it.
for k in range(rows * columns):
    fig.add_subplot(rows, columns, k + 1)
    plt.imshow(img[k].cpu().reshape((im_size, im_size)))
    plt.axis('off')

## Query: counterfactual data generation

Next we ask a *counterfactual* question: given an observed digit $X$, what
would the digit have looked like had its intensity $i$ been set to $190$?

To compute this quantity we would normally:
   1. invert the model to infer the latent exogenous noise $u$ (abduction)
   2. construct an intervened model
   3. re-simulate the intervened model using the inferred $u$ [@pearl2011algorithmization].

However, we can equivalently
represent this process with inference in a single, expanded
probabilistic program containing two copies of every deterministic
statement (a so-called "twin network" representation of
counterfactuals, first described in Chapter 7 of [@pearl] and extended
to the PPL setting in [@tavares_2020]).

In [None]:
x_obs = images[0]  # factual observation: the first downsampled training image
x_obs_pixels = x_obs.cpu().detach().reshape((im_size, im_size))
plt.imshow(x_obs_pixels)

In [None]:
def deep_scm_query(model: DeepSCM, intensity_value: float = 190.0):
    """Build the counterfactual query "what would X have been had I been ``intensity_value``?".

    The returned function takes an observed image ``x_obs``. It runs ``model``
    under ``MultiWorldCounterfactual(dim=-1)`` with the intervention
    ``do(I = intensity_value)``, conditioning site "X" on ``x_obs`` flattened to
    ``(-1, input_dim)``.

    ``model.device`` / ``model.input_dim`` are used when the model defines them;
    otherwise the device is taken from the model's parameters and ``input_dim``
    from the notebook global ``im_size``.
    """
    def _device():
        device = getattr(model, "device", None)
        return device if device is not None else next(model.parameters()).device

    def _input_dim():
        input_dim = getattr(model, "input_dim", None)
        return input_dim if input_dim is not None else im_size * im_size

    def query_model(x_obs):
        device = _device()
        with MultiWorldCounterfactual(dim=-1), \
                do(actions={'I': torch.tensor([intensity_value], device=device)}), \
                condition(data={"X": x_obs.reshape(-1, _input_dim()).to(device=device)}):
            return model()

    return query_model

# Apply causal_pyro's TransformInferReparam to the site "X_observed".
# NOTE(review): presumably this is the factual copy of X created by
# MultiWorldCounterfactual — confirm the site name.
reparam_handler = pyro.poutine.reparam(config={"X_observed": TransformInferReparam()})
cf_model = reparam_handler(deep_scm_query(scm))

In [None]:
%pdb on
fig = plt.figure()
fig.suptitle("Twin World Counterfactual Model")
rows = 1
columns = 2

# Run the counterfactual query ONCE so both panels come from the same stochastic
# execution. T is not observed, so each call draws a new value; one call per
# panel would compare unrelated samples.
x_cf = cf_model(x_obs)
# NOTE(review): cf_model returns the tensor produced by the model's forward (X),
# not a dict, so it is indexed directly rather than with "X". The [0][0] / [0][1]
# factual / counterfactual indexing is kept from the original cell — confirm it
# against the world-dimension layout MultiWorldCounterfactual produces.
panels = ((x_cf[0][0], "Actual Model"), (x_cf[0][1], "Intervened Model"))
for position, (panel, title) in enumerate(panels, start=1):
    fig.add_subplot(rows, columns, position)
    # detach(): the output may carry gradients, which matplotlib cannot convert.
    plt.imshow(panel.cpu().detach().reshape((im_size, im_size)))
    plt.title(title)
    plt.axis('off')
fig.tight_layout()

Like all counterfactuals, this estimand is not identified in general
without further assumptions: learning parameters $\theta$ that match
observed data does not guarantee that the counterfactual distribution
will match that of the true causal model. 

However, as discussed in the
original paper [@pawlowski2020deep] in the context of modeling MRI
images, there are a number of valid practical reasons one might wish to
compute it anyway, such as explanation or expert evaluation.