In [None]:
import torch
class FumbrellaMetrics():
    """Accumulator for statistics on stage-1 vs. stage-2 gradient differences.

    ``add`` records one pair of gradients; ``compute`` gathers everything
    recorded so far through an accelerator object, reduces it, and resets
    the accumulator.
    """

    def __init__(self, batch_size):
        # Factor every raw gradient difference is multiplied by in ``add``.
        self.batch_size = batch_size
        # Creates the four (empty) per-step lists.
        self.clear()

    def add(self, stage1_grad, stage2_grad):
        """Record the difference between the first entries of two gradient tuples.

        Only index 0 of each argument is used. The difference is multiplied
        by ``batch_size`` before any statistic is derived from it.
        NOTE(review): the reductions below need at least a 2-D tensor;
        presumably dim 0 is the batch and dim 1 the classes -- confirm.
        """
        scaled = (stage2_grad[0] - stage1_grad[0]) * self.batch_size
        # Mean absolute difference along dim 0.
        per_class = scaled.abs().mean(dim=0)

        self.rescaled_diffs.append(scaled)
        # Norm taken along dim 1 of the rescaled difference.
        self.vector_norms.append(torch.linalg.vector_norm(scaled, dim=1))
        self.avg_diff_per_class.append(per_class)
        self.avg_diff_all_classes.append(per_class.mean())

    def compute(self, accelerator):
        """Gather the recorded statistics, reduce them and reset the accumulator.

        ``accelerator`` must provide ``gather_for_metrics``. Whenever that
        call returns a ``list`` it is reduced here (norms are concatenated,
        the averages are stacked and averaged); any other return value is
        passed through untouched.

        Returns a dict with the keys ``avg_grad_diff_all_classes``,
        ``avg_grad_diff_per_class`` and ``vector_norms``.
        """
        # (result key, recorded values, reduction applied to a gathered list),
        # processed in this order.
        plan = (
            ("vector_norms", self.vector_norms, torch.cat),
            (
                "avg_grad_diff_per_class",
                self.avg_diff_per_class,
                lambda parts: torch.stack(parts).mean(dim=0),
            ),
            (
                "avg_grad_diff_all_classes",
                self.avg_diff_all_classes,
                lambda parts: torch.stack(parts).mean(),
            ),
        )

        reduced = {}
        for key, recorded, combine in plan:
            gathered = accelerator.gather_for_metrics(recorded)
            if isinstance(gathered, list):
                gathered = combine(gathered)
            reduced[key] = gathered

        self.clear()

        key_order = (
            "avg_grad_diff_all_classes",
            "avg_grad_diff_per_class",
            "vector_norms",
        )
        return {key: reduced[key] for key in key_order}

    def clear(self):
        """Drop every recorded value by rebinding the lists to fresh empty ones."""
        self.vector_norms, self.rescaled_diffs = [], []
        self.avg_diff_per_class, self.avg_diff_all_classes = [], []


class Fumbrella:
    """Two-stage gradient hook manager attached to one module of a model.

    Stage 1 stores the gradient seen at the chosen position of ``module``
    (its input or its output). Stage 2 hands the stored gradient together
    with the current one to the metrics object and returns the stored
    gradient from the backward hook, so that PyTorch continues the backward
    pass with the stage-1 gradient.

    ``set_stage`` also switches the dropout probability of every
    ``torch.nn.Dropout`` in ``model`` and toggles which parameters require
    gradients.
    """

    _POSITIONS = ('input', 'output')

    def __init__(
            self,
            module,
            model,
            batch_size,
            stage1_dropout : float,
            stage2_dropout : float,
            position : str = 'input',
            stage = 1
            ):
        """Register the hooks on ``module`` and put ``model`` into ``stage``.

        Args:
            module: the module whose input/output gradient is captured.
            model: the model searched for dropout modules and parameters.
            batch_size: forwarded to ``FumbrellaMetrics``.
            stage1_dropout: dropout probability applied in stage 1.
            stage2_dropout: dropout probability applied in stage 2.
            position: ``'input'`` or ``'output'``.
            stage: stage to start in.

        Raises:
            ValueError: if ``position`` is neither ``'input'`` nor ``'output'``.
        """
        # Fail early: an unknown position would otherwise register no hooks
        # and only surface later as an unrelated AttributeError.
        if position not in self._POSITIONS:
            raise ValueError(
                f"position must be one of {self._POSITIONS}, got {position!r}"
            )
        self.position = position
        self.module = module
        self._register_hooks()
        self._prepare_module_lists(model)
        self.dropout_rates = {
            1 : stage1_dropout,
            2 : stage2_dropout
        }
        self.set_stage(stage)
        self.stage1_grad = None
        # metrics
        self.metrics = FumbrellaMetrics(batch_size)

    def _register_hooks(self):
        """Attach the forward/backward hook pair that matches ``self.position``."""
        if self.position == 'input':
            self.fhook = self.module.register_forward_pre_hook(self._forward_pre_hook_fn)
            self.bhook = self.module.register_full_backward_hook(self._backward_hook_fn)
        elif self.position == 'output':
            self.fhook = self.module.register_forward_hook(self._forward_hook_fn)
            self.bhook = self.module.register_full_backward_pre_hook(self._backward_pre_hook_fn)

    @staticmethod
    def _enable_grad(tensors):
        """Turn on gradient tracking for every differentiable tensor in ``tensors``."""
        for tensor in tensors:
            # Skip non-tensors (None, ints, ...) and integer/bool tensors,
            # which cannot require gradients.
            if isinstance(tensor, torch.Tensor) and (
                tensor.is_floating_point() or tensor.is_complex()
            ):
                # requires_grad_(True) leaves tensors that already track
                # gradients untouched.
                tensor.requires_grad_(True)

    def _forward_pre_hook_fn(self, module, input):
        """Stage 1: make the module inputs require gradients.

        Returns ``input`` in stage 1 and ``None`` otherwise.
        """
        if self.stage == 1:
            self._enable_grad(input)
            return input

    def _forward_hook_fn(self, module, input, output):
        """Stage 1: make the module output(s) require gradients.

        Returns ``output`` in stage 1 and ``None`` otherwise.
        """
        if self.stage == 1:
            if isinstance(output, torch.Tensor):
                # A bare tensor must be flagged itself; iterating it would
                # only touch its slices along dim 0.
                self._enable_grad((output,))
            elif isinstance(output, (tuple, list)):
                self._enable_grad(output)
            # Any other output type is returned unchanged.
            return output

    def _exchange_gradient(self, current_grad):
        """Shared backward logic: store in stage 1, compare and replace in stage 2."""
        if self.stage == 1:
            self.stage1_grad = current_grad
            return None
        if self.stage == 2:
            if self.stage1_grad is None:
                raise RuntimeError(
                    "stage 2 backward pass ran before any stage 1 gradient was stored"
                )
            self.metrics.add(self.stage1_grad, current_grad)
            return self.stage1_grad
        return None

    def _backward_hook_fn(self, module, grad_input, grad_output):
        """Full backward hook (position ``'input'``), operating on ``grad_input``."""
        return self._exchange_gradient(grad_input)

    def _backward_pre_hook_fn(self, module, grad_input, grad_output=None):
        """Full backward pre-hook (position ``'output'``), operating on the output gradient.

        PyTorch invokes backward pre-hooks as ``hook(module, grad_output)``,
        so when only two arguments arrive the second one *is* the output
        gradient. The three-argument form is still accepted for callers that
        relied on the previous signature.
        """
        if grad_output is None:
            grad_output = grad_input
        return self._exchange_gradient(grad_output)

    def _prepare_module_lists(self, model):
        """Collect the dropout modules and parameter lists ``set_stage`` toggles."""
        self.dropout_modules = [m for m in model.modules() if isinstance(m, torch.nn.Dropout)]
        self.all_parameters = list(model.parameters())
        if self.position == 'input':
            # The hooked module's own parameters are re-enabled in stage 1.
            self.stage1_parameters = list(self.module.parameters())
        elif self.position == 'output':
            self.stage1_parameters = []

    def set_stage(self, stage: int):
        """Switch to ``stage``: update dropout rates and parameter ``requires_grad``."""
        self.stage = stage
        for m in self.dropout_modules:
            m.p = self.dropout_rates[stage]

        # Every parameter is trainable only in stage 2 ...
        for p in self.all_parameters:
            p.requires_grad = stage == 2
        # ... while the stage-1 parameters are trainable only in stage 1.
        for p in self.stage1_parameters:
            p.requires_grad = stage == 1

    def compute_diff_metrics(self,accelerator):
        """Return ``self.metrics.compute(accelerator)``."""
        return self.metrics.compute(accelerator)

    def close(self):
        """Remove both hooks from the module."""
        self.fhook.remove()
        self.bhook.remove()