In [None]:
# default_exp training_module

In [None]:
# hide
%load_ext autoreload
%autoreload 2
from ipynb_path import *

In [None]:
# export
from counternet.import_essentials import *
from counternet.utils.all import *
from counternet.evaluation import SensitivityMetric, ProximityMetric


# shared logger for this module, attached to PyTorch Lightning's 'lightning' logger
pl_logger = logging.getLogger(name='lightning')

Global seed set to 31


In [None]:
# Record the library versions this notebook was last executed with.
print(f"pl version: {pl.__version__}")
print(f"torch version: {torch.__version__}")

pl version: 1.3.5
torch version: 1.8.0


In [None]:
%%time
# Load the two example datasets; both paths are relative to the notebook's working directory.
dummy_data = pd.read_csv('assets/data/dummy_data.csv')
adult_data = load_adult_income_dataset('assets/data/adult.data')

CPU times: user 774 ms, sys: 29.5 ms, total: 803 ms
Wall time: 802 ms


## Utils

### Normalize the one-hot-encoded categorical features

In [None]:
# export utils.processing
class CategoricalNormalizer:
    """Post-process the categorical (one-hot) column groups of a feature matrix.

    The first categorical feature occupies columns
    ``[cat_idx, cat_idx + len(categories[0]))``. The next feature follows immediately, and so on.
    `normalize` rewrites every group so that its elements lie in [0, 1] and sum
    to 1 (softmax), or so that it becomes a one-hot vector (hard Gumbel-softmax).

    Args:
        categories: one sequence per categorical feature. Only each sequence's
            length (its number of levels) is used.
        cat_idx: column index at which the first categorical group starts.
    """
    def __init__(self, categories: List[List[Any]], cat_idx: int):
        self.categories = categories
        self.cat_idx = cat_idx

    def normalize(self, x, hard=False):
        """Normalize every categorical group of ``x`` in place and return ``x`` itself.

        Args:
            x: 2-D tensor whose rows are samples. Columns before ``cat_idx`` are
                left untouched.
            hard: if True, each group is replaced by a hard (one-hot)
                Gumbel-softmax sample. Otherwise a plain softmax is applied.

        Note: each group is detached before being normalized, so no gradient flows
        back through the categorical columns.
        """
        start = self.cat_idx
        for levels in self.categories:
            stop = start + len(levels)
            group = x[:, start:stop].clone().detach()
            x[:, start:stop] = (
                F.gumbel_softmax(group, hard=hard) if hard else F.softmax(group, dim=-1)
            )
            start = stop
        return x

### Define Metrics

In [None]:
# export evaluation
class SensitivityMetric(Metric):
    """Share of counterfactuals whose prediction survives reverting their small edits.

    For each (x, c) pair, any continuous feature whose change |x - c| is below
    ``threshold`` in the un-scaled space is reset to x's value. The predictor is then
    re-run on the reverted counterfactual, and its rounded output is compared with the
    rounded prediction originally given for c.

    Args:
        predict_fn: callable mapping a preprocessed batch to predictions.
        scaler: fitted scaler for the continuous columns. It must offer
            ``transform`` and ``inverse_transform`` on tensors.
        cat_idx: number of leading continuous columns (categorical columns start here).
        threshold: absolute change, in un-scaled units, below which an edit is "small".
    """
    def __init__(self, predict_fn: Callable, scaler: ABCScaler, cat_idx: int, threshold: float):
        super().__init__(dist_sync_on_step=False)
        self.predict_fn = predict_fn
        self.scaler = scaler
        self.cat_idx = cat_idx
        self.threshold = threshold

        # rows having at least one small continuous change
        self.add_state("total_n_changes", default=torch.tensor(0), dist_reduce_fx="sum")
        # rows whose rounded prediction differs after the small changes are reverted
        self.add_state("diffs", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, x: torch.Tensor, c: torch.Tensor, c_y: torch.Tensor):
        """Accumulate one batch. ``c`` is NOT modified.

        Args:
            x: preprocessed inputs. The first ``cat_idx`` columns are continuous.
            c: preprocessed counterfactuals, laid out like ``x``.
            c_y: predictions for ``c``.
        """
        # back to the un-scaled feature space, where `threshold` is expressed
        x_cont_inv = self.scaler.inverse_transform(x[:, :self.cat_idx])
        c_cont_inv = self.scaler.inverse_transform(c[:, :self.cat_idx])
        # True where a continuous feature changed by less than the threshold
        small_change = torch.abs(x_cont_inv - c_cont_inv) < self.threshold
        self.total_n_changes += torch.sum(small_change.any(dim=1))
        # counterfactual with the small changes reverted to x's values
        c_cont_hat = torch.where(small_change, x_cont_inv, c_cont_inv)
        # write into a copy so the caller's counterfactual batch stays intact
        c_hat = c.clone()
        c_hat[:, :self.cat_idx] = self.scaler.transform(c_cont_hat)
        c_y_hat = self.predict_fn(c_hat)

        self.diffs += (torch.round(c_y) != torch.round(c_y_hat)).sum()

    def compute(self):
        """Return ``1 - diffs / total_n_changes``.

        ``diffs`` counts every row seen, not only the rows that had a small change.
        There is no guard: when ``total_n_changes`` is 0 the division is by zero,
        giving NaN or -inf.
        """
        return 1 - self.diffs / self.total_n_changes

In [None]:
# rows 0-4: every feature moves by less than the threshold (1.0), so all are reverted;
# rows 5-9: every feature moves by at least the threshold, so they are left untouched
x = torch.rand((10, 4))
c = x.clone()
c[:5, :] += torch.rand((5, 4))
c[5:, :] += torch.tensor([1.1, -2.1, 1.01, -1.2])

def pred_func(arr):
    return torch.mean(arr, dim=1) * 10

scaler = StandardScaler().fit(x)
c_y = pred_func(scaler.transform(c))

sensitivity = SensitivityMetric(predict_fn=pred_func, scaler=scaler, cat_idx=4, threshold=1.)
sensitivity.update(scaler.transform(x), scaler.transform(c), c_y)
score, diffs, total_n_changes = sensitivity.compute(), sensitivity.diffs, sensitivity.total_n_changes

# reverting the small edits changes all 5 affected predictions, so the score is 0
assert torch.equal(score, torch.tensor(0.))
assert torch.equal(diffs, torch.tensor(5))
assert torch.equal(total_n_changes, torch.tensor(5))

In [None]:
# columns 0 and 3 move by 1.1 (above the 1.0 threshold) and are kept;
# columns 1-2 move by less than 1.0 and are reverted in every row
x = torch.rand((10, 4))
c = x + torch.tensor([1.1, 0, 0, -1.1])
c[:, 1:3] = c[:, 1:3] + torch.rand((10, 2))

pred_func = lambda arr: torch.mean(arr, dim=1) * 10
# fit the scaler on this cell's `x` BEFORE transforming `c`, so `c_y` does not depend
# on whatever `scaler` an earlier cell left in the kernel
scaler = StandardScaler().fit(x)
c_y = pred_func(scaler.transform(c))

sensitivity = SensitivityMetric(predict_fn=pred_func, scaler=scaler, cat_idx=4, threshold=1.)
sensitivity.update(scaler.transform(x), scaler.transform(c), c_y)
score = sensitivity.compute()
diffs = sensitivity.diffs
total_n_changes = sensitivity.total_n_changes

assert torch.equal(score, torch.tensor(0.))
assert torch.equal(diffs, torch.tensor(10))
assert torch.equal(total_n_changes, torch.tensor(10))

In [None]:
# export evaluation
def proximity(x: torch.Tensor, c: torch.Tensor):
    """Mean L1 distance between inputs ``x`` and counterfactuals ``c``.

    Absolute differences are summed over the last dimension (features), then averaged
    over all remaining dimensions. A pair of 1-D tensors yields their plain L1 distance.
    """
    l1_per_row = (x - c).abs().sum(dim=-1)
    return l1_per_row.mean()

In [None]:
x = torch.tensor([1, 2, 1])
c = torch.tensor([-1, 1., 0.1])
assert torch.isclose(proximity(x, c), torch.tensor(3.9))

# a different pair with the same per-feature absolute differences must give the same distance
x_ = torch.tensor([1.5, 2.5, 1])
c_ = torch.tensor([-0.5, 1.5, 0.1])
assert torch.isclose(proximity(x_, c_), torch.tensor(3.9))


In [None]:
# batched input: `c` holds the rows of `x` in swapped order, so each row is at L1 distance 3.9
x = torch.tensor([[1, 2, 1], [-1, 1., 0.1]])
c = x[[1, 0]]
assert proximity(x, c) == torch.tensor(3.9)

In [None]:
# export evaluation
class ProximityMetric(Metric):
    """Running mean of the per-sample L1 distance between inputs and counterfactuals.

    Distances are accumulated per sample, not per batch, so batches of different sizes
    (e.g. a smaller final batch) are weighted by how many samples they contain.
    """
    def __init__(self):
        super().__init__(dist_sync_on_step=False)
        # sum of per-sample L1 distances seen so far
        self.add_state("dist", default=torch.tensor(0.), dist_reduce_fx="sum")
        # number of samples seen so far
        self.add_state("n", default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, x, c):
        """Add one batch. A pair of 1-D tensors counts as a single sample."""
        # L1 distance over the last (feature) dimension, one value per sample
        per_sample = torch.abs(x - c).sum(dim=-1)
        self.dist += per_sample.sum()
        self.n += per_sample.numel()

    def compute(self):
        """Mean per-sample distance, or -1 if no sample has been seen."""
        if self.n == 0:
            return -1
        else:
            return self.dist / self.n

In [None]:
# running mean over two single-sample updates, both at L1 distance 3.9
metric = ProximityMetric()
x, c = torch.tensor([1, 2, 1]), torch.tensor([-1, 1., 0.1])
metric.update(x, c)
assert metric.compute() == torch.tensor(3.9)

x_, c_ = torch.tensor([1.5, 2.5, 1]), torch.tensor([-0.5, 1.5, 0.1])
metric.update(x_, c_)
assert metric.compute() == torch.tensor(3.9)


In [None]:
# one batched update holding two rows, each at L1 distance 3.9
x = torch.tensor([[1, 2, 1], [-1, 1., 0.1]])
c = x.flip(0)
metric = ProximityMetric()
metric.update(x, c)

assert metric.compute() == torch.tensor(3.9)

### Define other utility functions for training

In [None]:
# export utils.functional
def l1_mean(x, c):
    """Mean absolute error between ``x`` and ``c``, divided by the mean of ``|x|``.

    There is no guard: the result is inf/NaN when ``x`` is all zeros.
    """
    mean_abs_error = F.l1_loss(x, c, reduction='mean')
    scale = x.abs().mean()
    return mean_abs_error / scale

def get_loss_functions(f_name: str):
    """Return the loss function registered under ``f_name``.

    Known names: 'cross_entropy' (binary cross-entropy), 'l1', 'l1_mean' and 'mse'.
    Any other name fails the assertion below.
    """
    registry = dict(
        cross_entropy=F.binary_cross_entropy,
        l1=F.l1_loss,
        l1_mean=l1_mean,
        mse=F.mse_loss,
    )
    assert f_name in registry, \
        f'function name `{f_name}` is not in the loss function list {registry.keys()}'
    return registry[f_name]


In [None]:
# export utils.functional
def split_X_y(data: pd.DataFrame):
    """Split ``data`` into features (every column but the last) and the label (last column).

    Returns:
        ``(X, y)``, where X holds the leading columns and y the final column.
    """
    columns = data.columns
    feature_cols, label_col = columns[:-1], columns[-1]
    return data[feature_cols], data[label_col]

def train_val_test_split(X, y):
    """Split ``X``/``y`` sequentially (no shuffling) into train / validation / test parts.

    The first 70% of the rows go to training, the next 10% to validation and the
    remaining 20% to testing. Cut points are truncated with ``int``.

    Args:
        X, y: sliceable sequences of equal length (arrays, tensors, lists, ...).

    Returns:
        ``((X_train, y_train), (X_val, y_val), (X_test, y_test))``
    """
    assert len(X) == len(y)
    size = len(X)
    train_end = int(0.7 * size)   # rows [0, 70%) -> training
    val_end = int(0.8 * size)     # rows [70%, 80%) -> validation; the rest -> testing

    # build the tuples with literals: `tuple(a, b)` raises TypeError (tuple takes one iterable)
    return (
        (X[:train_end], y[:train_end]),
        (X[train_end:val_end], y[train_end:val_end]),
        (X[val_end:], y[val_end:]),
    )

## Base Module

In [None]:
# export
class ABCBaseModule(ABC):
    """Interface shared by the training modules.

    Subclasses must override `model_forward`, `forward` and `predict`. Until all
    three are overridden, the class cannot be instantiated.
    """

    @abstractmethod
    def model_forward(self, *x):
        """Abstract hook: run the underlying network."""
        raise NotImplementedError()

    @abstractmethod
    def forward(self, *x):
        """Abstract hook: forward pass of the module."""
        raise NotImplementedError()

    @abstractmethod
    def predict(self, *x):
        """Abstract hook: produce predictions."""
        raise NotImplementedError()

In [None]:
# export
class BaseModule(pl.LightningModule, ABCBaseModule):
    """Shared data loading, preprocessing and configuration for the training modules.

    Required ``configs`` keys:
        'data_dir' (CSV path), 'continous_cols', 'discret_cols', 'lr', 'batch_size',
        and 'encoder_dims'. The first entry of 'encoder_dims' is indexed
        unconditionally and must equal the width of the preprocessed feature matrix.
    Optional keys (default):
        'dropout' (0.3), 'lambda_1' / 'lambda_2' / 'lambda_3' (1), 'threshold' (1),
        'smooth_y' (True), 'loss_1' / 'loss_2' / 'loss_3' ('mse'),
        'decoder_dims' / 'explainer_dims' ([]).
    """
    def __init__(self, configs: Dict[str, Any]):
        super().__init__()
        self.save_hyperparameters(configs)

        # read data
        self.data = pd.read_csv(Path(configs['data_dir']))
        self.continous_cols = configs['continous_cols']
        self.discret_cols = configs['discret_cols']
        self.__check_cols()

        # set training configs
        self.lr = configs['lr']
        self.batch_size = configs['batch_size']
        self.dropout = configs.get('dropout', 0.3)
        self.lambda_1 = configs.get('lambda_1', 1)
        self.lambda_2 = configs.get('lambda_2', 1)
        self.lambda_3 = configs.get('lambda_3', 1)
        self.threshold = configs.get('threshold', 1)
        self.smooth_y = configs.get('smooth_y', True)

        # loss functions, looked up by name (mean-squared error unless configured)
        self.loss_func_1 = get_loss_functions(configs.get('loss_1', 'mse'))
        self.loss_func_2 = get_loss_functions(configs.get('loss_2', 'mse'))
        self.loss_func_3 = get_loss_functions(configs.get('loss_3', 'mse'))

        # set model configs
        self.enc_dims = configs.get('encoder_dims', [])
        self.dec_dims = configs.get('decoder_dims', [])
        self.exp_dims = configs.get('explainer_dims', [])

        # example input for Lightning's graph logging; its width is the model input size
        self.example_input_array = torch.randn((1, self.enc_dims[0]))

    def __check_cols(self):
        """Check that the configured columns cover the CSV, and cast continuous columns to float."""
        # NOTE(review): this requires EVERY CSV column (label included) to be listed in
        # continous_cols + discret_cols. However, prepare_data drops the last column before
        # selecting those lists -- confirm how the label column is meant to be configured.
        assert sorted(list(self.data.columns)) == sorted(self.continous_cols + self.discret_cols)
        # builtin `float` is what the deprecated `np.float` alias (removed in NumPy 1.24) meant
        self.data = self.data.astype({col: float for col in self.continous_cols})

    def training_epoch_end(self, outs):
        """Log the hyper-parameters once, at the end of the first epoch."""
        if self.current_epoch == 0:
            self.logger.log_hyperparams(self.hparams)

    def prepare_data(self):
        """Preprocess ``self.data`` and build the train / validation / test datasets.

        Continuous columns go through `MinMaxScaler`, and categorical columns through
        `OneHotEncoder`. Continuous columns come first in the resulting feature matrix,
        which is split sequentially 70% / 10% / 20%. This also creates `cat_normalizer`
        and `sensitivity`, both of which need the index where the categorical columns start.
        """
        # TODO Decouple data preparation and use `LightningDataModule`
        # 70% for training, 10% for validation, 20% for testing
        X, y = split_X_y(self.data)

        # preprocessing
        self.scaler = MinMaxScaler()
        self.ohe = OneHotEncoder()
        # NOTE(review): the empty fallbacks below are NumPy arrays, while `torch.cat` and
        # `.size()` need tensors -- confirm what happens when a column group is empty.
        X_cont = self.scaler.fit_transform(X[self.continous_cols]) if self.continous_cols else np.array([[] for _ in range(len(X))])
        X_cat = self.ohe.fit_transform(X[self.discret_cols]) if self.discret_cols else np.array([[] for _ in range(len(X))])
        X = torch.cat((X_cont, X_cat), dim=1)

        # the continuous block fills the first len(continous_cols) COLUMNS, so the
        # categorical block starts right after it (len(X_cont) would be the row count)
        cat_idx = len(self.continous_cols)

        # init categorical normalizer to keep categorical features in one-hot-encoding format
        cat_arrays = self.ohe.categories_ if self.discret_cols else []
        self.cat_normalizer = CategoricalNormalizer(cat_arrays, cat_idx=cat_idx)

        # init sensitivity metric
        self.sensitivity = SensitivityMetric(
            predict_fn=self.predict, scaler=self.scaler, cat_idx=cat_idx, threshold=self.threshold)

        pl_logger.info(f"x_cont: {X_cont.size()}, x_cat: {X_cat.size()}")
        pl_logger.info(f"X shape: {X.size()}")

        assert X.size(-1) == self.enc_dims[0],\
            f'The input dimension X (shape: {X.shape[-1]})  != encoder_dims[0]: {self.enc_dims}'

        # prepare train & test
        train, val, test = train_val_test_split(X, y.to_numpy())
        self.train_dataset = NumpyDataset(*train)
        self.val_dataset = NumpyDataset(*val)
        self.test_dataset = NumpyDataset(*test)

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                          pin_memory=True, shuffle=True, num_workers=0)

    def val_dataloader(self):
        """Loader over the validation split, in order."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                          pin_memory=True, shuffle=False, num_workers=0)

    def test_dataloader(self):
        """Loader over the test split, in order."""
        return DataLoader(self.test_dataset, batch_size=self.batch_size,
                          pin_memory=True, shuffle=False, num_workers=0)

## Predictive Module

In [None]:
# export utils.functional
def uniform(shape: tuple, r1: float, r2: float, device=None):
    """Sample a tensor of the given ``shape`` uniformly from [r1, r2).

    ``shape`` may be a tuple, list or ``torch.Size``, including the empty shape
    ``()`` for a 0-d tensor, or a single int.
    """
    assert r1 < r2, f"Issue: r1 ({r1}) >= r2 ({r2})"
    # pass the shape as one argument: `torch.rand(*())` would call torch.rand() with no size
    return (r2 - r1) * torch.rand(shape, device=device) + r1


def smooth_y(y, device=None):
    """Randomised label smoothing for binary targets.

    Entries equal to 1 become a random value in [0.8, 0.95). Every other entry becomes
    a random value in [0.05, 0.2). Samples are drawn on ``y``'s device. ``device`` is
    accepted for backward compatibility but ignored, since ``torch.where`` needs all
    operands on one device.
    """
    return torch.where(y == 1,
        uniform(y.shape, 0.8, 0.95, device=y.device),
        uniform(y.shape, 0.05, 0.2, device=y.device))

In [None]:
# export
class PredictiveTrainingModule(BaseModule):
    """Training logic for a stand-alone binary classifier trained with binary cross-entropy.

    Subclasses implement `model_forward`. It receives the tuple of positional inputs
    given to `forward` and must return probabilities in [0, 1], because they are fed
    to `F.binary_cross_entropy`.
    """
    def __init__(self, configs: Dict[str, Any]):
        super().__init__(configs)
        # define metrics
        self.val_acc = Accuracy()

    def forward(self, *x):
        # hand all positional inputs to model_forward as a single tuple
        return self.model_forward(x)

    def predict(self, x):
        """Return 0/1 predictions for an already-preprocessed batch ``x``.

        Calls `freeze()` first and never unfreezes, so the module stays frozen afterwards.
        """
        self.freeze()
        y_hat = self(x)
        return torch.round(y_hat)

    def configure_optimizers(self):
        # a single Adam optimizer over all trainable parameters
        return torch.optim.Adam([p for p in self.parameters() if p.requires_grad], lr=self.lr)

    def training_step(self, batch, batch_idx):
        """BCE loss against the (optionally smoothed) label, which is the batch's last element."""
        *x, y = batch
        y_hat = self(*x)
        if self.smooth_y:
            y = smooth_y(y)
        loss = F.binary_cross_entropy(y_hat, y)

        # Logging to TensorBoard
        self.log('train/train_loss_1', loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation BCE loss and prediction accuracy."""
        *x, y = batch
        y_hat = self(*x)
        loss = F.binary_cross_entropy(y_hat, y)
        # the metric registered in __init__ is `val_acc` (there is no `self.accuracy`);
        # Accuracy expects integer class targets, while `y` is float for the BCE loss
        self.val_acc(y_hat, y.int())
        self.log('val/val_loss_1', loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log('val/pred_accuracy', self.val_acc, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)

## CounterNet Module

In [None]:
# export
class CFNetTrainingModule(BaseModule):
    """Joint training of a predictor and a counterfactual explainer.

    Subclasses implement `model_forward(x)` returning ``(y_hat, c)``: the predicted
    probability for ``x`` and a raw counterfactual with the same layout as ``x``.
    Two optimizers are configured. With optimizer_idx 0, `training_step` returns the
    predictor loss; with optimizer_idx 1, it returns the explainer loss.
    """
    def __init__(self, configs: Dict[str, Any]):
        super().__init__(configs)
        # define metrics
        self.pred_acc = Accuracy()         # predicted label vs. ground truth
        self.cf_acc = Accuracy()           # counterfactual's label vs. flipped predicted label
        self.proximity = ProximityMetric()

    def forward(self, x, hard: bool=False):
        """Return ``(y_hat, c)`` with the categorical part of ``c`` normalised.

        hard: if True, categorical features of ``c`` are made one-hot (hard Gumbel-softmax).
            Otherwise they are softmax-normalised.
        """
        y, c = self.model_forward(x)
        c = self.cat_normalizer.normalize(c, hard=hard)
        return y, c

    def predict(self, x):
        """Return 0/1 predictions for a batch ``x``.

        ``x`` is passed to `model_forward` unchanged, so it must already be preprocessed
        like the training data.

        NOTE(review): Lightning's `freeze()` sets requires_grad=False on every parameter
        and switches to eval mode, and nothing here undoes it. This method is also the
        SensitivityMetric's predict_fn (see `prepare_data`) and therefore runs inside
        `validation_step`. Confirm that parameters still train after a validation pass.
        """
        self.freeze()
        y_hat, c = self.model_forward(x)
        return torch.round(y_hat)

    def generate_cf(self, x, clamp=False):
        """Return counterfactuals for ``x`` with one-hot categorical features.

        clamp: if True, clip the raw counterfactual to [0, 1] before normalising.
        Also calls `freeze()` (see the note on `predict`).
        """
        self.freeze()
        y, c = self.model_forward(x)
        if clamp:
            c = torch.clamp(c, 0., 1.)
        return self.cat_normalizer.normalize(c, hard=True)

    def _logging_loss(self, *loss, stage: str, on_step: bool = False):
        # logs each loss as '<stage>/<stage>_loss_<i>' with 1-based i
        for i, l in enumerate(loss):
            self.log(f'{stage}/{stage}_loss_{i+1}', l, on_step=on_step, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)

    def _loss_functions(self, x, c, y, y_hat, y_prime=None, is_val=False):
        """Compute the three training losses.

        x: input batch
        c: counterfactual batch
        y: ground-truth labels
        y_hat: predicted probabilities for x
        y_prime: target label for the counterfactual. Defaults to the flipped
            predicted label.
        is_val: if True, labels are not smoothed.

        Returns ``(l_1, l_2, l_3)``: prediction loss (y_hat vs. y), proximity loss
        (x vs. c), and validity loss (prediction on c vs. y_prime).
        """
        # flip zero/one
        if y_prime is None:
            y_prime = (y_hat < .5).clone().detach().float()

        c_y, _ = self(c)
        # loss functions
        if self.smooth_y and not is_val:
            y = smooth_y(y)
            y_prime = smooth_y(y_prime)
        l_1 = self.loss_func_1(y_hat, y)
        l_2 = self.loss_func_2(x, c)
        l_3 = self.loss_func_3(c_y, y_prime)

        return l_1, l_2, l_3

    def configure_optimizers(self):
        # two Adam optimizers over the same trainable parameters: 0 -> predictor, 1 -> explainer
        opt_1 = torch.optim.Adam([p for p in self.parameters() if p.requires_grad], lr=self.lr)
        opt_2 = torch.optim.Adam([p for p in self.parameters() if p.requires_grad], lr=self.lr)
        return (opt_1, opt_2)

    def predictor_step(self, l_1, l_3):
        # the validity term (lambda_3 * l_3) is currently excluded from the predictor loss
        p_loss = self.lambda_1 * l_1
        self.log('train/p_loss', p_loss, on_step=False, on_epoch=True, sync_dist=True)
        return p_loss

    def explainer_step(self, l_2, l_3):
        e_loss = self.lambda_2 * l_2 + self.lambda_3 * l_3
        self.log('train/e_loss', e_loss, on_step=False, on_epoch=True, sync_dist=True)
        return e_loss

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Return the predictor loss (optimizer 0) or the explainer loss (optimizer 1)."""
        x, y = batch
        y_hat, c = self(x)
        l_1, l_2, l_3 = self._loss_functions(x, c, y, y_hat)

        result = 0
        if optimizer_idx == 0:
            result = self.predictor_step(l_1, l_3)

        if optimizer_idx == 1:
            result = self.explainer_step(l_2, l_3)

        # Logging to TensorBoard by default
        self._logging_loss(l_1, l_2, l_3, stage='train', on_step=False)
        return result

    def validation_step(self, batch, batch_idx):
        """Log validation losses and prediction / counterfactual metrics."""
        x, y = batch

        # hard=True makes the categorical part of the counterfactual exactly one-hot
        y_hat, c = self(x, hard=True)
        c_y, _ = self(c)

        l_1, l_2, l_3 = self._loss_functions(x, c, y, y_hat, is_val=True)
        # predictor loss + explainer loss, computed inline: predictor_step/explainer_step
        # would log these validation values under 'train/' keys
        loss = self.lambda_1 * l_1 + self.lambda_2 * l_2 + self.lambda_3 * l_3

        # logging val loss
        self._logging_loss(l_1, l_2, l_3, stage='val', on_step=False)

        # metrics (Accuracy expects integer class targets)
        y_hat_label = torch.round(y_hat)
        metrics = {
            'val/val_loss': loss,
            'val/pred_accuracy': self.pred_acc(y_hat_label, y.int()),
            'val/cf_proximity': self.proximity(x, c),
            'val/cf_accuracy': self.cf_acc(torch.round(c_y), (1 - y_hat_label).int()),
            'val/sensitivity': self.sensitivity(x, c, c_y),
        }
        self.log_dict(metrics, on_step=False, on_epoch=True, sync_dist=True)

## Model


In [None]:
# export model
class LinearBlock(nn.Module):
    """A fully-connected layer followed by LeakyReLU and dropout."""
    def __init__(self, input_dim, out_dim, dropout=0.3):
        super().__init__()
        linear = nn.Linear(input_dim, out_dim)
        self.block = nn.Sequential(linear, nn.LeakyReLU(), nn.Dropout(dropout))

    def forward(self, x):
        return self.block(x)

class MultilayerPerception(nn.Module):
    """Stack of `LinearBlock`s mapping ``dims[0] -> dims[1] -> ... -> dims[-1]``.

    Every layer, the last one included, applies LeakyReLU and dropout. With fewer than
    two entries in ``dims``, the stack is empty and `forward` returns its input unchanged.
    """
    def __init__(self, dims=[3, 100, 10], dropout=0.3):
        super().__init__()
        blocks = [
            LinearBlock(d_in, d_out, dropout=dropout)
            for d_in, d_out in zip(dims[:-1], dims[1:])
        ]
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        return self.model(x)


In [None]:
# export model
class BaselinePredictiveModel(PredictiveTrainingModule):
    """Plain classifier: encoder MLP -> decoder MLP -> single-output linear layer -> sigmoid."""
    def __init__(self, config):
        super().__init__(config)
        assert self.enc_dims[-1] == self.dec_dims[0], \
            f"(enc_dims[-1]={self.enc_dims[-1]}) != (dec_dims[0]={self.dec_dims[0]})"
        encoder = MultilayerPerception(self.enc_dims, self.dropout)
        decoder = MultilayerPerception(self.dec_dims, self.dropout)
        head = nn.Linear(self.dec_dims[-1], 1)
        self.model = nn.Sequential(encoder, decoder, head)

    def model_forward(self, x):
        """``x`` is the 1-tuple of inputs built by `forward`.

        Returns sigmoid probabilities with the trailing size-1 dimension squeezed.
        """
        (inputs,) = x
        probs = torch.sigmoid(self.model(inputs))
        return torch.squeeze(probs, -1)

In [None]:
# export model
class CounterNetModel(CFNetTrainingModule):
    """Shared encoder feeding both a predictor head and a counterfactual explainer.

    The explainer sees the encoder output concatenated with the predictor's hidden
    output. It emits a counterfactual with ``encoder_dims[0]`` features, i.e. the
    input width.
    """
    def __init__(self, config):
        super().__init__(config)
        assert self.enc_dims[-1] == self.dec_dims[0], \
            f"(enc_dims[-1]={self.enc_dims[-1]}) != (dec_dims[0]={self.dec_dims[0]})"
        assert self.enc_dims[-1] == self.exp_dims[0], \
            f"(enc_dims[-1]={self.enc_dims[-1]}) != (exp_dims[0]={self.exp_dims[0]})"

        # the configured dropout is applied to every MLP, as in BaselinePredictiveModel
        self.encoder_model = MultilayerPerception(self.enc_dims, self.dropout)
        # predictor
        self.predictor = MultilayerPerception(self.dec_dims, self.dropout)
        self.pred_linear = nn.Linear(self.dec_dims[-1], 1)
        # explainer input = encoder output (exp_dims[0]) + predictor hidden output (dec_dims[-1])
        exp_dims = list(self.exp_dims)
        exp_dims[0] = self.exp_dims[0] + self.dec_dims[-1]

        self.explainer = nn.Sequential(
            MultilayerPerception(exp_dims, self.dropout),
            # exp_dims[-1] is the MLP's real output width; it differs from
            # self.exp_dims[-1] when explainer_dims has a single entry (empty MLP)
            nn.Linear(exp_dims[-1], self.enc_dims[0])
        )

    def model_forward(self, x):
        """Return ``(y_hat, c)`` for a preprocessed batch ``x``.

        y_hat: sigmoid probability with the trailing size-1 dimension squeezed.
        c: raw (not yet normalised) counterfactual with ``encoder_dims[0]`` features.
        """
        x = self.encoder_model(x)
        # predicted y_hat
        pred = self.predictor(x)
        y_hat = torch.sigmoid(self.pred_linear(pred))
        # the explainer is conditioned on both the encoding and the predictor's hidden output
        x = torch.cat((x, pred), -1)
        c = self.explainer(x)
        return torch.squeeze(y_hat, -1), c