# PyTorch Ablation Study
## Feature ablation
In principle this package can be installed from GitHub with pip, but the installation currently fails and I have not yet worked out why.

In [1]:
import inspect
import torch
import torch.nn as nn
import copy
from torch.utils.data import DataLoader


class Ablator:
    """Run layer- and feature-ablation trials on a PyTorch model.

    Trials are queued with :meth:`new_trial` and run with
    :meth:`execute_trials`. For each trial the requested layers are removed
    from a copy of the model, the requested feature is removed from the
    dataset, layer input sizes are re-matched, and ``training_fn`` is called
    with the resulting model and a freshly built ``DataLoader``.

    Args:
        model: The model to ablate. Its modules are listed with
            ``model.modules()`` and rebuilt into an ``nn.Sequential``.
        dataset: Dataset exposing a ``data`` attribute and an
            ``ablate_feature(feature)`` method.
        dataset_features: Stored as given; not used by this class itself.
        dataloader_kwargs: Keyword arguments forwarded to ``DataLoader``.
        training_fn: Callable ``training_fn(model, dataloader)`` whose return
            value is stored as the trial metric.
    """

    def __init__(self, model, dataset, dataset_features, dataloader_kwargs, training_fn):
        self.model = model
        self.dataset = dataset
        self.dataset_features = dataset_features
        self.dataloader_kwargs = dataloader_kwargs
        self.training_fn = training_fn

        self.trials = []
        self.state_dictionary = model.state_dict()

    def ablate_layers(self, idx_list, input_shape, infer_activation=False):
        """Return a new model with the modules at ``idx_list`` removed.

        Args:
            idx_list: ``None`` (no layer ablation), an integer index, or a
                list of integer indices into the model's module list.
            input_shape: Shape of one input sample, without the batch
                dimension.
            infer_activation: When true, an activation module that directly
                follows an ablated index is ablated as well.

        Returns:
            A deep copy of the model when ``idx_list`` is ``None``; otherwise
            an ``nn.Sequential`` built from copies of the remaining modules.

        Raises:
            TypeError: If ``idx_list`` is not ``None``, an int or a list.
            RuntimeError: If the remaining layers cannot be made compatible.
        """
        if idx_list is None:
            # A copy, so that consecutive feature-only trials do not keep
            # training the very same model.
            return copy.deepcopy(self.model)
        if isinstance(idx_list, int) and not isinstance(idx_list, bool):
            idx_list = [idx_list]
        elif not isinstance(idx_list, list):
            raise TypeError("idx_to_ablate should be an integer or a list of integers")

        # Copy the modules for the same reason as above: the surviving layers
        # must not share weights with the original model or with other trials.
        new_modules = copy.deepcopy(self._get_module_list())

        if infer_activation:
            following_activations = {
                idx + 1
                for idx in idx_list
                if (idx + 1) < len(new_modules) and self._is_activation(new_modules[idx + 1])
            }
            idx_list = list(set(idx_list) | following_activations)

        ablated_modules = self.remove_modules(new_modules, idx_list)
        correct_modules = self._match_model_features(ablated_modules, input_shape)
        return nn.Sequential(*correct_modules)

    @staticmethod
    def _rebuild_layer(layer, in_size):
        """Create a fresh layer of the same type with ``in_size`` inputs.

        Constructor arguments are recovered from the instance attributes whose
        names match the constructor signature. Re-initialisation is necessary
        because changing the declared shape alone would leave the layer with
        the wrong number of weights.
        """
        layer_type = type(layer)
        accepted = inspect.signature(layer_type).parameters
        new_args = {name: value for name, value in vars(layer).items() if name in accepted}

        # ``bias`` is registered as a parameter, not stored in ``__dict__``,
        # so it has to be recovered separately or ``bias=False`` is lost.
        if "bias" in accepted and "bias" not in new_args:
            registered = getattr(layer, "_parameters", {})
            if "bias" in registered:
                new_args["bias"] = registered["bias"] is not None

        if "in_features" in new_args:
            new_args["in_features"] = in_size
        elif "in_channels" in new_args:
            new_args["in_channels"] = in_size

        return layer_type(**new_args)

    @staticmethod
    def _match_model_features(model_modules, input_shape):
        """Fix input sizes of layers so a sample can flow through them.

        A random tensor of shape ``(1,) + input_shape`` is pushed through the
        modules one by one. When a module raises ``RuntimeError`` it is
        replaced in place by a re-initialised module of the same type whose
        ``in_features`` / ``in_channels`` equals dimension 1 of the last
        successful output, and the module is retried once.

        Args:
            model_modules: Indexable, mutable collection of modules (a list
                or an ``nn.Sequential``); modified in place.
            input_shape: Shape of one input sample, without batch dimension.

        Returns:
            ``model_modules``, after the in-place replacements.

        Raises:
            RuntimeError: If a module still fails after being rebuilt.
        """
        tensor_shape = (1,) + input_shape
        last_valid_out_features = tensor_shape[1]
        input_tensor = torch.rand(tensor_shape)
        failed_attempts = 0
        i = 0

        while i < len(model_modules):
            layer = model_modules[i]

            try:
                # Only shapes matter here, so skip building an autograd graph.
                with torch.no_grad():
                    output_tensor = layer(input_tensor)
            except RuntimeError as err:
                failed_attempts += 1
                if failed_attempts > 1:
                    raise RuntimeError(
                        "Ablation failed. Check again what modules you are ablating"
                    ) from err
                model_modules[i] = Ablator._rebuild_layer(layer, last_valid_out_features)
                continue

            failed_attempts = 0
            last_valid_out_features = output_tensor.shape[1]
            input_tensor = output_tensor
            i += 1

        return model_modules

    def new_trial(self, input_shape, ablated_layers=None, ablated_features=None, infer_activation=False):
        """Queue a trial; see :meth:`ablate_layers` for the argument meaning."""
        self.trials.append(Trial(input_shape, ablated_layers, ablated_features, infer_activation))

    def execute_trials(self):
        """Run every queued trial and store its metric on the trial object."""
        for i, trial in enumerate(self.trials):
            print("Starting trial", i)

            original_data = self.dataset.data

            try:
                # 1) Ablate layers
                ablated_model = self.ablate_layers(trial.ablated_layers, trial.input_shape, trial.infer_activation)

                # 2) Ablate features
                if trial.ablated_features is not None:
                    print("Ablating features:", trial.ablated_features)
                    self.dataset.ablate_feature(trial.ablated_features)

                # 3) Match features in model
                self._match_model_features(ablated_model, trial.input_shape)

                # 4) Train
                dataloader = DataLoader(self.dataset, **self.dataloader_kwargs)
                trial.metric = self.training_fn(ablated_model, dataloader)
                print("Final metric:", trial.metric, "\n\n")
            finally:
                # 5) Restore original data, even when the trial failed, so the
                #    dataset is never left in an ablated state.
                self.dataset.data = original_data

    def _get_module_list(self):
        """Return the model's modules, without the model itself.

        ``modules()`` walks the model recursively, so nested containers would
        be listed alongside their children; only flat models are represented
        faithfully.
        """
        modules = list(self.model.modules())
        # In PyTorch the first module is actually a description of the whole model
        modules.pop(0)
        return modules

    def remove_modules(self, modules_list, modules_to_ablate):
        """Remove the given indices from ``modules_list`` in place and return it.

        Indices are de-duplicated and removed from the highest down, so that
        earlier removals do not shift the positions still to be removed.
        """
        for i in sorted(set(modules_to_ablate), reverse=True):
            self._ablate_and_print(modules_list, i)
        return modules_list

    @staticmethod
    def _ablate_and_print(modules, i):
        """Pop element ``i`` from ``modules`` and report what was removed."""
        ablated = modules.pop(i)
        print("Ablating ", i, " - ", ablated, sep="")

    @staticmethod
    def _is_activation(layer):
        """Return whether ``layer``'s class name is a class in ``torch.nn.modules.activation``."""
        from torch.nn.modules import activation
        activation_names = {name for name, _ in inspect.getmembers(activation, inspect.isclass)}
        return layer.__class__.__name__ in activation_names


class MaggyDataset:
    """
    Base class for datasets whose features can be ablated.

    PyTorch's Dataset and DataLoader classes offer no way to obtain the whole dataset: the only method
     guaranteed to exist is __getitem__, and nothing specifies what it returns (a row of a tabular dataset,
     a (label, row) tuple, ...). This class therefore keeps the complete data in the attribute ``data`` and
     asks subclasses to say how a feature is removed from it (feature ablation is defined only on tabular
     datasets for now).
    """

    def __init__(self, data):
        """Store the complete dataset in ``self.data``."""
        self.data = data

    def ablate_feature(self, feature):
        """Remove ``feature`` from ``self.data``; subclasses must override this."""
        raise NotImplementedError


class Trial:
    """Plain record describing one ablation trial and, once run, its metric."""

    def __init__(self, input_shape, ablated_layers, ablated_features, infer_activation):
        # What to ablate.
        self.ablated_layers, self.ablated_features = ablated_layers, ablated_features
        # How to build the ablated model.
        self.input_shape, self.infer_activation = input_shape, infer_activation
        # Filled in after the trial has been executed.
        self.metric = None


In [28]:
# This should work but then I can't import anything

!pip install -e git+https://github.com/alessiomolinari/PyTorch_Ablation_Study#egg=Pytorch_Ablation_Study

ERROR: Could not detect requirement name for 'git+https://github.com/alessiomolinari/PyTorch_Ablation_Study@master', please specify one with #egg=your_package_name


In [27]:
from pytorch_ablation import ablator

### Here starts the demo

For this demo we are going to use the CIFAR10 dataset with a PyTorch sequential model. Of course in a classical feature ablation study it's quite common to eliminate one or more columns from a dataset and retrain the model to see if the performance changes.  
What if the dataset is not tabular? We might be dealing with images or time series, which need a more refined kind of ablation.  
As PyTorch allows the user a lot of freedom in defining models and datasets, this framework follows the same philosophy and lets the user decide how to ablate her own features.  
<br>
**Step 1: Get your dataset**

In [3]:
from torchvision import datasets, models, transforms
from pytorch_ablation.ablator import Ablator, MaggyDataset
from torch import nn, optim
import numpy as np

# Data preparation

# Training pipeline: random 32x32 crop (after padding by 4), random horizontal flip,
# conversion to tensor, then per-channel normalization.
# NOTE(review): the Normalize constants look like per-channel mean/std statistics of the
#  dataset — confirm their source.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Test pipeline: no augmentation, only tensor conversion and the same normalization.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Both splits are stored under ./data and downloaded on first use (download=True).
cifar_train = datasets.cifar.CIFAR10(root="data", train=True, download=True, transform=transform_train)
cifar_test = datasets.cifar.CIFAR10(root="data", train=False, download=True, transform=transform_test)

Files already downloaded and verified
Files already downloaded and verified


One of the requirements of the framework is to put your dataloader arguments in a dictionary rather than building the actual dataloader.  
This is necessary because different dataloaders will be dynamically built after the feature ablation trials.

In [4]:
# Keyword arguments only; the Ablator passes them to DataLoader when it builds a loader.
dataloader_kwargs = {"batch_size": 128, "shuffle": True}

**Step 2: Build your model**  
<br>
At the moment the framework supports only PyTorch Sequential models.

In [5]:
# Spatial size for a 32x32 input: three unpadded 3x3 convolutions give 32 -> 30 -> 28 -> 26,
# and the 2x2 max-pool halves that to 13. With 10 output channels the flattened vector has
# 10 * 13 * 13 = 1690 features, which is what the first Linear layer must accept.
model = nn.Sequential(nn.Conv2d(3, 5, 3, 1),
                     nn.ReLU(),
                     nn.Conv2d(5, 7, 3, 1),
                     nn.ReLU(),
                     nn.Conv2d(7, 10, 3, 1),
                     nn.ReLU(),
                     nn.MaxPool2d(2),
                     nn.Dropout2d(0.25),
                     nn.Flatten(),
                     nn.Linear(1690, 128),
                     nn.ReLU(),
                     nn.Dropout(0.5),
                     nn.Linear(128, 10),
                     nn.LogSoftmax(dim=1)
                     )

# Hyperparameters (read as globals by the training function)

device = "cpu"
epochs = 3
lr = 0.001

**Step 3: Define your training function**  
<br>
The first and second argument of the training function should be:
1. Model  
2. Dataloader 

While the **returned object** is the metric that you want to store when a trial is finished.

In [6]:
def train(net, trainloader):
    """Train ``net`` on ``trainloader`` and return the loss of the last batch.

    Uses the notebook-level globals ``lr``, ``epochs`` and ``device``.

    Args:
        net: Model to train; moved to ``device`` before training.
        trainloader: Iterable yielding ``(inputs, targets)`` batches.

    Returns:
        The detached loss tensor of the final batch, or ``None`` when the
        loader yielded no batch at all.
    """
    # The batches are moved to `device`, so the model must live there too.
    net.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr,
                          momentum=0.9, weight_decay=5e-4)

    # Stays None if the loader is empty, instead of leaving `loss` unbound.
    loss = None

    for epoch in range(epochs):
        print('\nEpoch: %d' % epoch)
        net.train()

        for inputs, targets in trainloader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

    # Detach so the stored metric does not keep the autograd graph alive.
    return loss.detach() if loss is not None else None

**Step 4: define your ablation**  
A MaggyDataset has two requirements:
1. All data should be stored in the attribute data of your PyTorch dataset
2. For every possible feature ablation you should define how you actually ablate this data (i.e. override the method ablate_feature)

For instance, in this example we have CIFAR10 and we define how to ablate the first and the second channel.  
Since the dataset was just loaded from the package torchvision and we have no access to the code of the class Dataset we need to define the function *ablate_feature* and manually attach it to CIFAR10

In [24]:
# Define custom ablation
def ablate_feature(self, feature):
    """Drop one channel from ``self.data`` along axis 3.

    ``"ch1"`` removes index 0 and ``"ch2"`` removes index 1; any other
    value leaves the data untouched.
    """
    channel_by_name = (("ch1", 0), ("ch2", 1))
    for name, channel in channel_by_name:
        if feature == name:
            self.data = np.delete(self.data, obj=channel, axis=3)


# We have no access to the source of the dataset class, so the function is attached to the
# CIFAR10 class itself; instances such as cifar_train then find it through their class.

setattr(datasets.cifar.CIFAR10, "ablate_feature", ablate_feature)

In [26]:
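# Why doesn't this work? The call passes no `feature` argument, which ablate_feature requires (e.g. "ch1").
# NOTE(review): the error also lists `self` as missing, so the name seems to resolve to a plain, unbound
#  function here; possibly an earlier cell attached it to the cifar_train instance itself — TODO confirm.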

cifar_train.ablate_feature()

TypeError: ablate_feature() missing 2 required positional arguments: 'self' and 'feature'

**Last thing: define your trials**  
A trial has the form (input_shape, ablated_layers, ablated_features)  
<br>
*input_shape* is the shape of the tensor that you feed to your network, without the batch dimension.  
*ablated_layers* is an integer or a list of integers representing the indices of the layers that you want to ablate.  
*ablated_features* is the feature that you want to ablate from your dataset. For now only one feature per trial is supported.  
<br>
Instantiate your ablator, add the trials and then run them with:
```
ablator.execute_trials()
```

In [23]:
# NOTE(review): four positional arguments are passed here, while the Ablator class defined in the
#  first cell takes five (dataset_features sits between dataset and dataloader_kwargs). This cell
#  uses the Ablator imported from pytorch_ablation.ablator, whose signature presumably differs —
#  TODO confirm.
ablator = Ablator(model, cifar_train, dataloader_kwargs, train)

# Trial arguments: (input_shape, ablated_layers, ablated_features)
# ablator.new_trial((3, 32, 32), None, None)
# Feature ablation only: drop "ch1", so the input has two channels.
ablator.new_trial((2, 32, 32), None, "ch1")
# Layer ablation only: remove the module at index 2.
ablator.new_trial((3, 32, 32), 2, None)

ablator.execute_trials()

Starting trial 0
Ablating features: ch1


TypeError: ablate_feature() missing 1 required positional argument: 'feature'