In [None]:
# default_exp cf_explainer

In [None]:
# hide
%load_ext autoreload
%autoreload 2
from ipynb_path import *

In [None]:
# export
from counternet.import_essentials import *
from counternet.utils import CategoricalNormalizer, flip_binary, load_configs, l1_mean, hinge_loss
from counternet.base_interface import ExplainerBase, LocalExplainerBase, GlobalExplainerBase
from counternet.training_module import BaseModule
from counternet.model import MultilayerPerception

Global seed set to 31


In [None]:
# export
class Clamp(torch.autograd.Function):
    """
    Clamp a tensor to the closed interval [0, 1] while letting gradients
    flow through unchanged (the backward pass ignores the clamping).
    code from: https://discuss.pytorch.org/t/regarding-clamped-learnable-parameter/58474/4
    """
    @staticmethod
    def forward(ctx, input):
        # Values below 0 become 0, values above 1 become 1.
        clamped = torch.clamp(input, min=0, max=1)
        return clamped

    @staticmethod
    def backward(ctx, grad_output):
        # Pass the incoming gradient through as-is (copied, not aliased).
        passthrough = grad_output.clone()
        return passthrough

In [None]:
# export
class VanillaCF(LocalExplainerBase):
    def __init__(self, pred_fn: Callable[[torch.Tensor], torch.Tensor],
                       cat_normalizer: CategoricalNormalizer=None, configs: Dict[str, Any]=None):
        """vanilla version of counterfactual generation
            - link: https://doi.org/10.2139/ssrn.3063289
        Args:
            pred_fn: callable mapping an input tensor to model predictions.
            cat_normalizer: optional normalizer applied to the counterfactual
                (soft during optimisation, `hard=True` on the final result).
            configs: optional settings; `steps` (default 1000) is the number
                of optimisation steps. May be a plain dict or an object with
                attribute access.
        """
        # Build a fresh dict per instance instead of sharing one mutable default.
        configs = {} if configs is None else configs
        super().__init__(pred_fn, configs)
        self.steps = self._config_value(configs, 'steps', 1000)
        self.cat_normalizer = cat_normalizer
        # The learnable counterfactual; created in `generate_cf`.
        self.cf = None

    @staticmethod
    def _config_value(configs, key, default):
        """Read `key` from `configs`, supporting dicts and attribute-style configs."""
        if key not in configs:
            return default
        # Plain dicts have no attribute access; other config objects may not
        # support subscripting, so pick the access style by type.
        return configs[key] if isinstance(configs, dict) else getattr(configs, key)

    def forward(self):
        """Return the current counterfactual (normalized when a normalizer is set).

        Raises:
            NotImplementedError: if `generate_cf` has not initialised `self.cf`.
        """
        if self.cf is None:
            raise NotImplementedError('cf has not been initialized...')
        cf = self.cf * 1.0
        return cf if self.cat_normalizer is None else self.cat_normalizer.normalize(cf)

    def configure_optimizers(self):
        """RMSprop over the counterfactual parameter only."""
        return torch.optim.RMSprop([self.cf], lr=0.001)

    def _loss_functions(self, x, c):
        """Return (validity loss, proximity loss) for input `x` and candidate `c`."""
        # target: the flipped prediction of the original input
        y_pred = self.pred_fn(x)
        y_prime = flip_binary(y_pred)

        c_y = self.pred_fn(c)
        l_1 = F.binary_cross_entropy(c_y, y_prime)
        l_2 = F.mse_loss(x, c)
        return l_1, l_2

    def _loss_compute(self, l_1, l_2):
        """Weighted sum of the validity and proximity terms."""
        return 1.0 * l_1 + 0.5 * l_2

    def generate_cf(self, x: torch.Tensor):
        """Optimise a counterfactual starting from `x` for `self.steps` steps.

        Returns:
            The detached counterfactual, hard-normalized when a
            `cat_normalizer` is set.
        """
        self.cf = nn.Parameter(x.clone(), requires_grad=True)
        optim = self.configure_optimizers()
        for _ in range(self.steps):
            c = self()
            l_1, l_2 = self._loss_functions(x, c)
            loss = self._loss_compute(l_1, l_2)
            optim.zero_grad()
            loss.backward()
            optim.step()

        cf = self.cf.clone().detach() * 1.0
        return cf if self.cat_normalizer is None else self.cat_normalizer.normalize(cf, hard=True)

In [None]:
# export
class DiverseCF(LocalExplainerBase):
    def __init__(self, pred_fn: Callable[[torch.Tensor], torch.Tensor],
                       cat_normalizer: CategoricalNormalizer=None, configs: Dict[str, Any]=None):
        """diverse counterfactual explanation
            - link: https://doi.org/10.1145/3351095.3372850
        Args:
            pred_fn: callable mapping an input tensor to model predictions.
            cat_normalizer: optional normalizer; when set it drives the
                categorical sum-to-one penalty and the final hard normalization.
            configs: optional settings; `steps` (default 1000) is the number of
                optimisation steps and `n_cfs` (default 5) the number of
                counterfactuals. May be a plain dict or an object with
                attribute access.
        """
        # Build a fresh dict per instance instead of sharing one mutable default.
        configs = {} if configs is None else configs
        super().__init__(pred_fn, configs)
        self.steps = self._config_value(configs, 'steps', 1000)
        self.cat_normalizer = cat_normalizer
        self.n_cfs = self._config_value(configs, 'n_cfs', 5)
        # The learnable (n_cfs, n_features) counterfactuals; created in `generate_cf`.
        self.cf = None

    @staticmethod
    def _config_value(configs, key, default):
        """Read `key` from `configs`, supporting dicts and attribute-style configs."""
        if key not in configs:
            return default
        # Plain dicts have no attribute access; other config objects may not
        # support subscripting, so pick the access style by type.
        return configs[key] if isinstance(configs, dict) else getattr(configs, key)

    def forward(self):
        """Return the current counterfactual candidates.

        Raises:
            NotImplementedError: if `generate_cf` has not initialised `self.cf`.
        """
        if self.cf is None:
            raise NotImplementedError('cf has not been initialized...')
        cf = self.cf * 1.0
        return cf

    def configure_optimizers(self):
        """RMSprop over the counterfactual parameter only."""
        return torch.optim.RMSprop([self.cf], lr=0.001)

    def _compute_dist(self, x1, x2):
        """L1 distance between two tensors, summed over dim 0."""
        return torch.sum(torch.abs(x1 - x2), dim = 0)

    def _compute_proximity_loss(self, x: torch.Tensor = None):
        """Mean per-feature L1 distance between each counterfactual and `x`.

        Not used by `_loss_functions` (which uses `l1_mean`); kept as a
        standalone helper. The previous version referenced attributes that do
        not exist on this class and could never run.

        Args:
            x: the query instance; flattened before comparison, so it is
                expected to hold a single instance.
        Raises:
            ValueError: if `x` is not provided.
        """
        if x is None:
            raise ValueError('x is required to compute the proximity loss')
        x_flat = x.reshape(-1)
        proximity_loss = 0.0
        for i in range(self.n_cfs):
            proximity_loss += self._compute_dist(self.cf[i], x_flat)
        # Normalise by (number of features * number of counterfactuals).
        return proximity_loss / (x_flat.numel() * self.n_cfs)

    def _dpp_style(self, cf):
        """Determinant of the inverse-L1-distance kernel over the counterfactuals."""
        cfs = cf[: self.n_cfs]
        # Pairwise L1 distances via broadcasting: entry (i, j) = sum |cf_i - cf_j|.
        pairwise_l1 = (cfs.unsqueeze(1) - cfs.unsqueeze(0)).abs().sum(dim=-1)
        # implement inverse distance
        kernel = 1.0 / (1.0 + pairwise_l1)
        # Small diagonal jitter; created on cf's device/dtype so non-CPU inputs work.
        kernel = kernel + torch.eye(self.n_cfs, device=cf.device, dtype=cf.dtype) * 0.0001
        return torch.det(kernel)

    def _compute_diverse_loss(self, cf):
        """Diversity term expressed as a loss to minimise.

        The DPP determinant grows as counterfactuals spread apart, and the
        total loss is minimised, so the determinant is negated here (DiCE
        subtracts the diversity term from its objective).
        """
        return -self._dpp_style(cf)

    def _compute_regularization_loss(self):
        """Penalty pushing each one-hot group of every counterfactual to sum to 1."""
        if self.cat_normalizer is None:
            return 0.

        regularization_loss = 0.
        # assumes one-hot groups are laid out back to back starting at
        # `cat_normalizer.cat_idx` -- TODO confirm against CategoricalNormalizer
        start = self.cat_normalizer.cat_idx
        for col in self.cat_normalizer.categories:
            end = start + len(col)
            group_sums = self.cf[: self.n_cfs, start:end].sum(dim=1)
            regularization_loss = regularization_loss + torch.pow(group_sums - 1.0, 2).sum()
            # Advance to the next categorical feature's slice.
            start = end
        return regularization_loss

    def _loss_functions(self, x, c):
        """Return (validity, proximity, diversity, categorical penalty) losses."""
        # target: complement of the prediction on the original input
        y_pred = self.pred_fn(x)
        # `1.0 - y_pred` keeps y_pred's device, unlike a fresh CPU ones-tensor.
        y_prime = 1.0 - y_pred

        c_y = self.pred_fn(c)
        # yloss
        l_1 = hinge_loss(input=c_y, target=y_prime.float())
        # proximity loss
        l_2 = l1_mean(x, c)
        # diverse loss
        l_3 = self._compute_diverse_loss(c)
        # categorical penalty
        l_4 = self._compute_regularization_loss()
        return l_1, l_2, l_3, l_4

    def _compute_loss(self, *loss_f):
        """Unweighted sum of all loss terms."""
        return sum(loss_f)

    def generate_cf(self, x: torch.Tensor):
        """Optimise `n_cfs` randomly initialised counterfactuals for `x`.

        Returns:
            The detached counterfactuals clamped to [0, 1], hard-normalized
            when a `cat_normalizer` is set.
        """
        # Random start on the same device as the query instance.
        self.cf = nn.Parameter(torch.rand(self.n_cfs, x.size(1), device=x.device), requires_grad=True)
        optim = self.configure_optimizers()
        for _ in range(self.steps):
            c = self()

            l_1, l_2, l_3, l_4 = self._loss_functions(x, c)
            loss = self._compute_loss(l_1, l_2, l_3, l_4)
            optim.zero_grad()
            loss.backward()
            optim.step()

        cf = self.cf.clone().detach() * 1.0
        cf = torch.clamp(cf, 0, 1)
        return cf if self.cat_normalizer is None else self.cat_normalizer.normalize(cf, hard=True)

In [None]:
# exporti
class AE(BaseModule):
    """Autoencoder: an encoder MLP down to `encoded_size`, mirrored by a decoder MLP."""

    def __init__(self, configs, encoded_size=5):
        super().__init__(configs)
        n_features = configs['encoder_dims'][0]
        hidden_widths = [20, 16, 14, 12]
        # Encoder is built first, then the decoder with the widths reversed.
        self.encoder_model = MultilayerPerception([n_features, *hidden_widths, encoded_size])
        self.decoder_model = MultilayerPerception([encoded_size, *reversed(hidden_widths), n_features])

    def forward(self, x):
        """Encode `x` and decode it back to input space."""
        return self.decoder_model(self.encoded(x))

    def configure_optimizers(self):
        """Adam over the trainable parameters, using `self.lr`."""
        trainable = list(filter(lambda p: p.requires_grad, self.parameters()))
        return torch.optim.Adam(trainable, lr=self.lr)

    def encoded(self, x):
        """Latent code of `x`."""
        return self.encoder_model(x)

    def _reconstruction_loss(self, batch):
        """Mean squared error between the batch inputs and their reconstruction."""
        inputs, _ = batch
        reconstruction = self(inputs)
        return F.mse_loss(reconstruction, inputs, reduction='mean')

    def training_step(self, batch, batch_idx):
        """Log and return the reconstruction loss for a training batch."""
        loss = self._reconstruction_loss(batch)
        self.log('train/loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log and return the reconstruction loss for a validation batch."""
        loss = self._reconstruction_loss(batch)
        self.log('val/val_loss', loss)
        return loss

In [None]:
# exporti
class VAE(pl.LightningModule):
    """Conditional VAE: encoder and decoder both receive the target label `c`
    concatenated as one extra input column."""

    def __init__(self, input_dims, encoded_size=5):
        super().__init__()
        # "+ 1" on the input widths accounts for the concatenated condition column.
        self.encoder_mean = MultilayerPerception([input_dims + 1, 20, 16, 14, 12, encoded_size])
        self.encoder_var = MultilayerPerception([input_dims + 1, 20, 16, 14, 12, encoded_size])
        self.decoder_mean = MultilayerPerception([encoded_size + 1, 12, 14, 16, 20, input_dims])

    @staticmethod
    def _as_condition_column(c):
        """Detached float copy of `c` shaped (batch, 1), ready to concatenate."""
        return c.clone().detach().float().reshape(c.shape[0], 1)

    def encoder(self, x):
        """Return (mean, variance-like term) of the latent distribution.

        NOTE(review): the second value is named `logvar` but is used as a
        variance downstream (`sqrt`, `log`), so it must be positive; that
        depends on the output range of `encoder_var` -- TODO confirm.
        """
        mean = self.encoder_mean(x)
        logvar = 0.5+ self.encoder_var(x)
        return mean, logvar

    def decoder(self, z):
        """Decode latent codes (with condition column attached) to input space."""
        mean = self.decoder_mean(z)
        return mean

    def sample_latent_code(self, mean, logvar):
        """Reparameterised sample: mean + sqrt(logvar) * standard normal noise."""
        eps = torch.randn_like(logvar)
        return mean + torch.sqrt(logvar) * eps

    def normal_likelihood(self, x, mean, logvar, raxis=1):
        """Gaussian log-likelihood (up to a constant), summed over `raxis`."""
        # `raxis` was previously ignored in favour of a hard-coded axis of 1.
        return torch.sum( -.5 * ((x - mean)*(1./logvar)*(x-mean) + torch.log(logvar) ), dim=raxis)

    def forward(self, x, c, mc_samples: int = 50):
        """
        x: input instance
        c: target y
        mc_samples: number of latent samples to draw and decode (default 50)

        Returns a dict with the encoder outputs (`em`, `ev`), the list of
        sampled codes (`z`), the list of decoded samples (`x_pred`) and
        `mc_samples`.
        """
        # `.float()` must be called; the previous bare `.float` left `c` as a
        # bound method and made the concatenation below fail.
        c = self._as_condition_column(c)
        res = {}
        em, ev = self.encoder(torch.cat((x, c), 1))
        res['em'] = em
        res['ev'] = ev
        res['z'] = []
        res['x_pred'] = []
        res['mc_samples'] = mc_samples
        for _ in range(mc_samples):
            z = self.sample_latent_code(em, ev)
            x_pred = self.decoder(torch.cat((z, c), 1))
            res['z'].append(z)
            res['x_pred'].append(x_pred)
        return res

    def compute_elbo(self, x, c, model):
        """Single-sample pass used for evaluation and generation.

        Returns (log-likelihood placeholder fixed at 0, mean KL divergence,
        `x`, decoded sample, `model.predict` of the decoded sample).
        """
        c = self._as_condition_column(c)
        em, ev = self.encoder(torch.cat((x,c),1))
        kl_divergence = 0.5*torch.mean(em**2 + ev - torch.log(ev) - 1, axis=1)

        z = self.sample_latent_code(em, ev)
        dm= self.decoder( torch.cat((z,c),1) )
        # The reconstruction log-likelihood is not computed here.
        log_px_z = torch.tensor(0.0)

        x_pred= dm
        return torch.mean(log_px_z), torch.mean(kl_divergence), x, x_pred, model.predict(x_pred)

In [None]:
# export
class VAE_CF(BaseModule, GlobalExplainerBase):
    def __init__(self, config: Dict, pred_model: pl.LightningModule):
        """
        config: basic configs; optional key `validity_reg` weights the
            validity (hinge) term of the loss
        pred_model: the black-box model to be explained (frozen here)
        """
        super().__init__(config)
        self.pred_model = pred_model
        self.pred_model.freeze()
        self.vae = VAE(input_dims=self.enc_dims[0])
        # Defaults to 1.0 when not configured. The DiCE VAE example uses 42.0:
        # https://interpret.ml/DiCE/notebooks/DiCE_getting_started_feasible.html#Generate-counterfactuals-using-a-VAE-model
        self.validity_reg = config['validity_reg'] if 'validity_reg' in config.keys() else 1.0

    def model_forward(self, x):
        """lazy implementation since this method will not be used"""
        recon_err, kl_err, x_true, x_pred, cf_label = self.vae.compute_elbo(x, 1 - self.pred_model.predict(x), self.pred_model)
        # return y, c
        return cf_label, x_pred

    def forward(self, x):
        """lazy implementation since this method will not be used"""
        return self.model_forward(x)

    def configure_optimizers(self):
        """Adam over the trainable parameters, using `self.lr`."""
        return torch.optim.Adam([p for p in self.parameters() if p.requires_grad], lr=self.lr)

    def predict(self, x):
        """Prediction of the wrapped black-box model."""
        return self.pred_model.predict(x)

    def _categorical_sum_penalty(self, x_pred):
        """Per-row sum of |1 - sum(one-hot group)| over all categorical features."""
        penalty = x_pred.new_zeros(x_pred.shape[0])
        # assumes one-hot groups follow the continuous columns back to back
        # -- TODO confirm against the data module's column layout
        start = len(self.continous_cols)
        for col in self.cat_normalizer.categories:
            end = start + len(col)
            penalty = penalty + torch.abs(1.0 - x_pred[:, start:end].sum(axis=1))
            # Advance to the next categorical feature's slice.
            start = end
        return penalty

    def compute_loss(self, out, x, y):
        """Training loss from the VAE output dict `out`.

        Combines the KL divergence, an L1 reconstruction (proximity) term with
        a categorical sum-to-one penalty, and a hinge validity term against
        the target `y`, averaged over the Monte Carlo samples in `out`.
        """
        em = out['em']
        ev = out['ev']
        x_preds = out['x_pred']
        mc_samples = out['mc_samples']
        target = y.float()

        #KL Divergence
        kl_divergence = 0.5*torch.mean(em**2 + ev - torch.log(ev) - 1, axis=1)

        recon_err = x.new_zeros(x.shape[0])
        validity_loss = torch.zeros(1, device=self.device)
        for i in range(mc_samples):
            x_pred = x_preds[i]

            #Reconstruction Term
            #Proximity: L1 Loss
            recon_err = recon_err - torch.sum(torch.abs(x - x_pred), axis=1)

            # Sum to 1 over the categorical indexes of a feature
            recon_err = recon_err - self._categorical_sum_penalty(x_pred)

            #Validity
            c_y = self.pred_model(x_pred)
            validity_loss = validity_loss + hinge_loss(input=c_y, target=target)

        recon_err = recon_err / mc_samples
        validity_loss = -1 * self.validity_reg * validity_loss / mc_samples

        return -torch.mean(recon_err - kl_divergence) - validity_loss

    def _counterfactual_loss(self, x):
        """Return (loss, target) where the target is the complement of the model's prediction."""
        # prediction
        y_hat = self.pred_model.predict(x)
        # target
        y = 1.0 - y_hat

        out = self.vae(x, y)
        return self.compute_loss(out, x, y), y

    def training_step(self, batch, batch_idx):
        """Log and return the loss for a training batch."""
        x, _ = batch
        loss, _ = self._counterfactual_loss(x)

        self.log('train/loss', loss)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss, L1 proximity and counterfactual accuracy."""
        x, _ = batch
        loss, y = self._counterfactual_loss(x)

        _, _, _, x_pred, cf_label = self.vae.compute_elbo(x, y, self.pred_model)

        cf_proximity = torch.abs(x - x_pred).sum(dim=1).mean()
        cf_accuracy = accuracy(cf_label, y.int())

        self.log('val/val_loss', loss)
        self.log('val/proximity', cf_proximity)
        self.log('val/cf_accuracy', cf_accuracy)

    def generate_cf(self, x):
        """Decode one counterfactual per row of `x`, hard-normalized by the
        prediction model's categorical normalizer."""
        self.vae.freeze()
        y_hat = self.pred_model.predict(x)
        recon_err, kl_err, x_true, x_pred, cf_label = self.vae.compute_elbo(x, 1.-y_hat, self.pred_model)
        return self.pred_model.cat_normalizer.normalize(x_pred, hard=True)