In [None]:
# NOTE(review): presumably pinned because the training cell below uses the
# torchtext.data (RawField / Field / BucketIterator) and
# torchtext.datasets.SST interfaces of this release -- confirm before bumping.
!pip install torchtext==0.4

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchtext==0.4
  Downloading torchtext-0.4.0-py3-none-any.whl (53 kB)
[K     |████████████████████████████████| 53 kB 1.7 MB/s 
Installing collected packages: torchtext
  Attempting uninstall: torchtext
    Found existing installation: torchtext 0.12.0
    Uninstalling torchtext-0.12.0:
      Successfully uninstalled torchtext-0.12.0
Successfully installed torchtext-0.4.0


In [None]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.20.1-py3-none-any.whl (4.4 MB)
[K     |████████████████████████████████| 4.4 MB 8.0 MB/s 
Collecting tokenizers!=0.11.3,<0.13,>=0.11.1
  Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)
[K     |████████████████████████████████| 6.6 MB 55.5 MB/s 
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.8.1-py3-none-any.whl (101 kB)
[K     |████████████████████████████████| 101 kB 14.1 MB/s 
[?25hCollecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 87.2 MB/s 
Installing collected packages: pyyaml, tokenizers, huggingface-hub, transformers
  Attempting uninstall: pyyaml
    Found existing installation: PyYAML 3.13
    Uninstal

In [None]:
"""
This script fine-tunes a BERT model on the SST.
"""
import torch
from torch import optim
from torchtext import data as tt
from torchtext.datasets import SST
from transformers import BertTokenizer, BertForSequenceClassification

# Use the GPU when one is present; otherwise fall back to the CPU so the
# cell still runs (slowly) on CPU-only runtimes instead of crashing.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load a pre-trained BERT model
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# NOTE(review): num_labels=4 presumably covers the SST classes plus the
# extra entry that label_field's vocabulary adds -- confirm against
# len(label_field.vocab).
model = BertForSequenceClassification.from_pretrained("bert-base-uncased",
                                                      return_dict=True,
                                                      num_labels=4)
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=3e-5, eps=1e-8)

# Load the data
text_field = tt.RawField()
label_field = tt.Field(sequential=False)
data = SST.splits(text_field, label_field)
label_field.build_vocab(data[0])
iters = tt.BucketIterator.splits(data, batch_size=32, device=device)


def encode_batch(texts):
    """
    Tokenizes a batch of raw sentences and moves every tensor of the
    encoding to the compute device.
    :param texts: The raw text of one batch (batch.text)
    :return: A dict of model inputs, ready to be unpacked into model()
    """
    encoded = tokenizer(texts, return_tensors="pt", padding=True)
    return {k: v.to(device) if isinstance(v, torch.Tensor) else v
            for k, v in encoded.items()}


def count_correct(logits, labels):
    """
    Counts the predictions (argmax over the last axis of the logits)
    that agree with the gold labels.
    :param logits: Model output logits for one batch
    :param labels: Gold label indices for the same batch
    :return: The number of correct predictions, as a Python int
    """
    # Reduce on the tensor instead of iterating over it element by
    # element with the built-in sum()
    return int((logits.argmax(-1) == labels).sum())


def evaluate(batches):
    """
    Measures classification accuracy on one data split without updating
    the model.
    :param batches: An iterator over the batches of the split
    :return: (number of correct predictions, number of examples)
    """
    model.eval()
    correct = 0
    total = 0
    # No gradients are needed for evaluation; disabling autograd avoids
    # building (and holding on to) the computation graph
    with torch.no_grad():
        for batch in batches:
            output = model(**encode_batch(batch.text))
            correct += count_correct(output.logits, batch.label)
            total += len(batch.label)
    return correct, total


# Train
best_accuracy = 0
for epoch in range(2):
    print("Epoch", epoch + 1)
    model.train()
    for i, batch in enumerate(iters[0]):
        model.zero_grad()

        # Forward pass
        output = model(**encode_batch(batch.text), labels=batch.label)

        # Compute accuracy
        num_correct = count_correct(output.logits, batch.label)
        accuracy = num_correct / len(batch.label) * 100
        print("Batch {}: {}/{} correct ({:.1f}%); loss = {}".format(
            i + 1, num_correct, len(batch.label), accuracy,
            float(output.loss)))

        # Backward pass
        output.loss.backward()
        optimizer.step()

    # Dev accuracy
    num_correct, num_total = evaluate(iters[1])
    accuracy = num_correct / num_total * 100
    print("Epoch {}: {}/{} correct ({:.1f}%)".format(
        epoch + 1, num_correct, num_total, accuracy))

    # Keep the weights of the epoch with the best dev accuracy
    if accuracy > best_accuracy:
        torch.save(model.state_dict(), "bert-sst.pt")
        best_accuracy = accuracy

# Testing
# NOTE(review): this scores the weights of the final epoch, which are not
# necessarily the ones saved to bert-sst.pt above.
num_correct, num_total = evaluate(iters[2])
accuracy = num_correct / num_total * 100
print("Test: {}/{} correct ({:.1f}%)".format(num_correct, num_total, accuracy))

torch.save(model.config, "bert-sst-config.pt")


Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/420M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

downloading trainDevTestTrees_PTB.zip


trainDevTestTrees_PTB.zip: 100%|██████████| 790k/790k [00:00<00:00, 3.57MB/s]


extracting
Epoch 1
Batch 1: 12/32 correct (37.5%); loss = 1.4057819843292236
Batch 2: 14/32 correct (43.8%); loss = 1.2313992977142334
Batch 3: 18/32 correct (56.2%); loss = 1.222074031829834
Batch 4: 15/32 correct (46.9%); loss = 1.2185885906219482
Batch 5: 18/32 correct (56.2%); loss = 1.1597216129302979
Batch 6: 16/32 correct (50.0%); loss = 1.1490731239318848
Batch 7: 12/32 correct (37.5%); loss = 1.1775858402252197
Batch 8: 19/32 correct (59.4%); loss = 1.0570167303085327
Batch 9: 14/32 correct (43.8%); loss = 1.1025631427764893
Batch 10: 17/32 correct (53.1%); loss = 1.0376204252243042
Batch 11: 21/32 correct (65.6%); loss = 1.0406761169433594
Batch 12: 19/32 correct (59.4%); loss = 1.0873944759368896
Batch 13: 16/32 correct (50.0%); loss = 1.1076548099517822
Batch 14: 18/32 correct (56.2%); loss = 1.0372604131698608
Batch 15: 17/32 correct (53.1%); loss = 1.0622031688690186
Batch 16: 14/32 correct (43.8%); loss = 1.1418291330337524
Batch 17: 16/32 correct (50.0%); loss = 1.14911

In [None]:
# Mount the user's Google Drive at /content/drive inside the Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import numpy as np


def lrp_linear_single(x: np.ndarray, y: np.ndarray, rel_y: np.ndarray,
                      w: np.ndarray = None, eps: float = 0.001) -> np.ndarray:
    """
    Propagates relevance through one linear layer with the LRP-epsilon
    rule, for a single (unbatched) example.

    :param x: Layer input (input_size,)
    :param y: Layer output (output_size,)
    :param rel_y: Relevance assigned to the layer output (output_size,)
    :param w: Weight matrix laid out as (input_size, output_size), i.e.
        y = w.T @ x + b. When omitted, the layer is treated as the
        identity map plus a bias: y = x + b
    :param eps: Stabilizer; added to y with the sign of y (an output of
        exactly zero counts as positive)
    :return: The relevance redistributed onto x (input_size,)
    """
    # Push every output away from zero before dividing by it
    sign = np.where(y >= 0, 1., -1.)
    ratio = rel_y / (y + eps * sign)

    if w is not None:
        # contributions[i, j] = w[i, j] * x[i]: share of input i in output j
        contributions = w * x[:, np.newaxis]
        return contributions @ ratio
    return x * ratio


def lrp_linear(x: np.ndarray, y: np.ndarray, rel_y: np.ndarray,
               w: np.ndarray = None, eps: float = 0.001) -> np.ndarray:
    """
    Batched LRP-epsilon rule for a linear layer.

    :param x: Layer input (..., input_size)
    :param y: Layer output (..., output_size)
    :param rel_y: Relevance assigned to the layer output
        (..., output_size)
    :param w: Transposed weight matrix (input_size, output_size), i.e.
        y = x @ w + b. When omitted, the layer is treated as the
        identity map plus a bias: y = x + b
    :param eps: Stabilizer; added to y with the sign of y (an output of
        exactly zero counts as positive)
    :return: The relevance redistributed onto x (..., input_size)
    """
    sign = np.where(y >= 0, 1., -1.)
    ratio = rel_y / (y + eps * sign)
    if w is None:
        return x * ratio

    # contributions[..., i, j] = w[i, j] * x[..., i]
    contributions = w[..., np.newaxis, :, :] * x[..., np.newaxis]
    # Sum the contributions over the output axis with a batched mat-vec
    propagated = contributions @ ratio[..., np.newaxis]
    return propagated.squeeze(-1)


def lrp_matmul(x: np.ndarray, w: np.ndarray, y: np.ndarray, rel_y: np.ndarray,
               eps: float = 0.001) -> np.ndarray:
    """
    LRP-epsilon rule for a batched matrix product y = w @ x, where w
    plays the role of the weights and relevance is assigned to x only.

    :param x: Input (..., m, n)
    :param w: Weight matrix (..., p, m)
    :param y: Output (..., p, n)
    :param rel_y: Relevance assigned to the output (..., p, n)
    :param eps: Stabilizer; added to y with the sign of y (an output of
        exactly zero counts as positive)
    :return: The relevance redistributed onto x (..., m, n)
    """
    w_t = w.swapaxes(-1, -2)  # (..., m, p)
    ratio = rel_y / (y + eps * np.where(y >= 0, 1., -1.))  # (..., p, n)

    # contributions[..., m, p, n] = w[p, m] * x[m, n]; the column axis n
    # is then moved in front so each column is an (m, p) matrix
    contributions = w_t[..., np.newaxis] * x[..., np.newaxis, :]
    contributions = np.moveaxis(contributions, -1, -3)  # (..., n, m, p)

    ratio_columns = ratio.swapaxes(-1, -2)[..., np.newaxis]  # (..., n, p, 1)
    rel_x = (contributions @ ratio_columns).squeeze(-1)  # (..., n, m)
    return rel_x.swapaxes(-1, -2)

In [None]:
from abc import ABC, abstractmethod
from typing import Tuple

import numpy as np
import torch
from scipy.special import expit
from torch import nn


class BackpropModuleMixin(ABC):
    """
    The general interface for modules with custom backward passes. When
    a module is in "attribution mode," the forward and backward
    functions are replaced with the custom functions (re-)implemented in
    NumPy.

    The mixin is meant to be listed before a PyTorch module class in the
    bases of a concrete class; train(), eval() and __call__() delegate
    to that class through super().
    """

    # When True, tensors passed to __call__ in attribution mode are
    # converted to NumPy arrays before attr_forward sees them. Subclasses
    # whose attr_forward needs the original tensors set this to False.
    _convert_attr_input_to_numpy = True

    def __init__(self, *args, **kwargs):
        super(BackpropModuleMixin, self).__init__(*args, **kwargs)
        self.attr_mode = False  # Am I in attribution mode?
        self._input = None  # Forward pass input
        self._output = None  # Forward pass output
        self._state = None  # Forward pass stored computations

    @staticmethod
    def _to_numpy(value):
        """
        Converts a tensor to a NumPy array; any other value is returned
        unchanged.
        """
        if isinstance(value, torch.Tensor):
            # .cpu() is a no-op for CPU tensors and makes the conversion
            # work for tensors living on another device
            return value.detach().cpu().numpy()
        return value

    def _leave_attr_mode(self):
        """
        Switches attribution mode off and drops the cached forward pass.
        """
        self.attr_mode = False
        self._input = None
        self._output = None
        self._state = None

    def train(self, *args, **kwargs):
        """
        Delegates to the parent train() and leaves attribution mode.
        :return: Whatever the parent train() returns, so that calls can
            be chained
        """
        result = super(BackpropModuleMixin, self).train(*args, **kwargs)
        self._leave_attr_mode()
        return result

    def eval(self):
        """
        Delegates to the parent eval() and leaves attribution mode.
        :return: Whatever the parent eval() returns, so that calls can
            be chained
        """
        result = super(BackpropModuleMixin, self).eval()
        self._leave_attr_mode()
        return result

    def attr(self):
        """
        Puts the module in attribution mode.
        """
        self.attr_mode = True

    def __call__(self, *args, **kwargs):
        """
        Runs attr_forward in attribution mode and the regular parent
        call otherwise.
        """
        if not self.attr_mode:
            return super(BackpropModuleMixin, self).__call__(*args, **kwargs)

        if self._convert_attr_input_to_numpy:
            args = [self._to_numpy(a) for a in args]
            kwargs = {k: self._to_numpy(v) for k, v in kwargs.items()}
        return self.attr_forward(*args, **kwargs)

    def backward(self, *args, **kwargs):
        """
        Runs attr_backward (on NumPy inputs) in attribution mode.
        Outside attribution mode the call is forwarded to the parent
        class; this raises AttributeError when the parent defines no
        backward().
        """
        if not self.attr_mode:
            # Delegate to the parent instead of calling self.backward,
            # which would recurse forever
            return super(BackpropModuleMixin, self).backward(*args, **kwargs)

        args = [self._to_numpy(a) for a in args]
        kwargs = {k: self._to_numpy(v) for k, v in kwargs.items()}
        return self.attr_backward(*args, **kwargs)

    @abstractmethod
    def attr_forward(self, *args, **kwargs):
        """
        The custom forward pass in NumPy.
        """
        raise NotImplementedError("attr_forward not implemented")

    @abstractmethod
    def attr_backward(self, *args, **kwargs):
        """
        The custom backward pass in NumPy.
        """
        raise NotImplementedError("attr_backward not implemented")


class BackpropLinear(BackpropModuleMixin, nn.Linear):
    """
    An interface for nn.Linear.
    """

    def attr_forward(self, x: np.ndarray):
        """
        NumPy forward pass: x @ weight.T, plus the bias when the layer
        has one.
        :param x: Layer input (..., in_features)
        :return: Layer output (..., out_features). The input is cached
            in self._input, the bias-free product in self._state["wx"]
            and the returned output in self._output
        """
        self._input = [x]
        # .cpu() is a no-op on CPU and keeps the conversion working when
        # the parameters live on another device
        wx = x @ self.weight.detach().cpu().numpy().T
        self._state = dict(wx=wx)
        if self.bias is None:
            # Layer was built with bias=False. Copy so that _output and
            # _state["wx"] stay distinct arrays, as in the biased case
            self._output = wx.copy()
        else:
            self._output = wx + self.bias.detach().cpu().numpy()
        return self._output


class BackpropRNNMixin(BackpropModuleMixin):
    """
    Shared NumPy forward-pass scaffolding for PyTorch recurrent modules.
    Concrete classes provide the cell equations in _layer_forward and
    set num_gates.
    """

    # Number of gate blocks stacked along the first axis of each
    # parameter; set by concrete classes
    num_gates = None

    def __init__(self, *args, **kwargs):
        # Every sequence is handled as (batch_size, seq_len, features)
        super().__init__(*args, **kwargs, batch_first=True)

    def attr_forward(self, x: np.ndarray):
        """
        Runs the recurrent network over all layers and, if the module is
        bidirectional, both directions. The cell equations live in the
        abstract helper _layer_forward.
        :param x: The network input, of shape (batch_size, seq_len,
            input_size)
        :return: The output of the last layer, of shape (batch_size,
            seq_len, hidden_size); for a bidirectional module the
            outputs of the two directions are concatenated along the
            last axis
        """
        depth = self.num_layers
        self._input = [None] * depth
        self._state = dict(ltr=[None] * depth)
        if self.bidirectional:
            self._state["rtl"] = [None] * depth

        layer_input = x
        for layer in range(depth):
            self._layer_forward(layer_input, layer, 0)
            if not self.bidirectional:
                layer_input = self._state["ltr"][layer][0]
                continue

            # The right-to-left pass reads the sequence reversed in time;
            # its hidden states are flipped back before concatenation
            self._layer_forward(np.flip(layer_input, 1), layer, 1)
            h_backward = np.flip(self._state["rtl"][layer][0], 1)
            h_forward = self._state["ltr"][layer][0]
            layer_input = np.concatenate((h_forward, h_backward), -1)

        self._output = layer_input
        return layer_input

    @abstractmethod
    def _layer_forward(self, x: np.ndarray, layer: int, direction: int):
        """
        Computes the forward pass of one layer in one direction.
        :param x: The input to the layer
        :param layer: The layer number
        :param direction: The direction number (0 for left to right, 1
            for right to left)
        :return: None; implementations store their result in
            self._state["ltr"][layer] or self._state["rtl"][layer], with
            the hidden states as the first element
        """
        raise NotImplementedError("_layer_forward not implemented")

    def _params_numpy(self, prefix: str, layer: int,
                      direction: int) -> Tuple[np.ndarray, ...]:
        """
        Looks up one parameter of a given layer and direction and splits
        it into its per-gate blocks.
        :param prefix: "weight_ih" for input weights, "weight_hh" for
            hidden state weights, "bias_ih" for input biases, or
            "bias_hh" for hidden state biases
        :param layer: The layer to retrieve the parameter for
        :param direction: The direction to retrieve the parameter for
            (1 selects the "_reverse" parameter)
        :return: num_gates equal blocks, split along the first axis
        """
        suffix = "_reverse" if direction == 1 else ""
        name = prefix + "_l" + str(layer) + suffix
        stacked = getattr(self, name).detach().numpy()
        return np.split(stacked, self.num_gates)


class BackpropLSTM(BackpropRNNMixin, nn.LSTM):
    """
    NumPy re-implementation of the nn.LSTM forward pass that keeps the
    per-timestep gate activations.
    """

    num_gates = 4

    def _layer_forward(self, x: np.ndarray, layer: int, direction: int):
        """
        Runs one LSTM layer in one direction, starting from zero hidden
        and cell states.
        :param x: The input to the layer (batch_size, seq_len,
            input_size)
        :param layer: The layer number
        :param direction: 0 for left to right, 1 for right to left
        :return: None; the trace (h, c, i, f, g, g_pre, w_ig.T, w_hg.T)
            is stored in self._state["ltr"][layer] for direction 0 and
            in self._state["rtl"][layer] otherwise
        """
        if direction == 0:
            self._input[layer] = x

        batch_size, seq_len, _ = x.shape
        x_cols = x[:, :, :, np.newaxis]  # column vector per timestep

        def params(prefix):
            return self._params_numpy(prefix, layer=layer,
                                      direction=direction)

        def as_columns(biases):
            return [b[:, np.newaxis] for b in biases]

        # Parameters, one block per gate (input, forget, cell, output)
        w_ii, w_if, w_ig, w_io = params("weight_ih")
        w_hi, w_hf, w_hg, w_ho = params("weight_hh")
        b_ii, b_if, b_ig, b_io = as_columns(params("bias_ih"))
        b_hi, b_hf, b_hg, b_ho = as_columns(params("bias_hh"))

        def pre_activation(x_t, h_last, w_x, b_x, w_h, b_h):
            return (w_x @ x_t + b_x + w_h @ h_last + b_h).squeeze(-1)

        # Per-timestep traces
        trace_shape = (batch_size, seq_len, self.hidden_size)
        h, c, i, f, g, g_pre, o = (np.zeros(trace_shape) for _ in range(7))

        # Recurrence
        h_prev = np.zeros((batch_size, self.hidden_size, 1))
        c_prev = np.zeros((batch_size, self.hidden_size))
        for t in range(seq_len):
            x_t = x_cols[:, t]
            i[:, t] = expit(pre_activation(x_t, h_prev, w_ii, b_ii,
                                           w_hi, b_hi))
            f[:, t] = expit(pre_activation(x_t, h_prev, w_if, b_if,
                                           w_hf, b_hf))
            cell_pre = pre_activation(x_t, h_prev, w_ig, b_ig, w_hg, b_hg)
            g_pre[:, t] = cell_pre
            g[:, t] = np.tanh(cell_pre)
            o[:, t] = expit(pre_activation(x_t, h_prev, w_io, b_io,
                                           w_ho, b_ho))

            c[:, t] = f[:, t] * c_prev + i[:, t] * g[:, t]
            h[:, t] = o[:, t] * np.tanh(c[:, t])

            h_prev = h[:, t, :, np.newaxis]
            c_prev = c[:, t]

        # Store the trace under the key of the direction just processed
        key = "ltr" if direction == 0 else "rtl"
        self._state[key][layer] = (h, c, i, f, g, g_pre, w_ig.T, w_hg.T)


class BackpropGRU(BackpropRNNMixin, nn.GRU):
    """
    An interface for nn.GRU.
    """

    num_gates = 3

    def _layer_forward(self, x: np.ndarray, layer: int, direction: int):
        """
        Runs one GRU layer in one direction in NumPy, starting from a
        zero hidden state, and records the trace of the computation.
        :param x: The input to the layer (batch_size, seq_len,
            input_size)
        :param layer: The layer number
        :param direction: 0 for left to right, 1 for right to left
        :return: None; the trace (h, r, z, n, n_pre, w_in.T, w_hn.T) is
            stored in self._state["ltr"][layer] for direction 0 and in
            self._state["rtl"][layer] otherwise
        """
        if direction == 0:
            self._input[layer] = x

        batch_size, seq_len, _ = x.shape
        x = x[:, :, :, np.newaxis]  # column vector per timestep

        # Get parameters (one block per gate: reset, update, new)
        kwargs = {"layer": layer, "direction": direction}
        w_ir, w_iz, w_in = self._params_numpy("weight_ih", **kwargs)
        w_hr, w_hz, w_hn = self._params_numpy("weight_hh", **kwargs)
        biases_i = self._params_numpy("bias_ih", **kwargs)
        biases_h = self._params_numpy("bias_hh", **kwargs)
        b_ir, b_iz, b_in = [b[:, np.newaxis] for b in biases_i]
        b_hr, b_hz, b_hn = [b[:, np.newaxis] for b in biases_h]

        # Initialize
        h = np.zeros((batch_size, seq_len, self.hidden_size))
        r = np.zeros((batch_size, seq_len, self.hidden_size))
        z = np.zeros((batch_size, seq_len, self.hidden_size))
        n_pre = np.zeros((batch_size, seq_len, self.hidden_size))
        n = np.zeros((batch_size, seq_len, self.hidden_size))

        # Forward pass
        h_prev = np.zeros((batch_size, self.hidden_size, 1))
        for t in range(seq_len):
            r_temp = (w_ir @ x[:, t] + b_ir + w_hr @ h_prev + b_hr).squeeze(-1)
            z_temp = (w_iz @ x[:, t] + b_iz + w_hz @ h_prev + b_hz).squeeze(-1)
            r[:, t] = expit(r_temp)
            z[:, t] = expit(z_temp)

            # The reset gate scales only the hidden-state part of the
            # candidate pre-activation
            n_pre_i = (w_in @ x[:, t] + b_in).squeeze(-1)
            n_pre_h = (w_hn @ h_prev + b_hn).squeeze(-1)
            n_pre[:, t] = n_pre_i + r[:, t] * n_pre_h
            n[:, t] = np.tanh(n_pre[:, t])

            h[:, t] = (1 - z[:, t]) * n[:, t] + z[:, t] * h_prev.squeeze(-1)
            h_prev = h[:, t, :, np.newaxis]

        # Save trace to state. This happens once, after the loop (as in
        # BackpropLSTM), so the trace is also stored for an empty
        # sequence instead of leaving the state entry as None
        if direction == 0:
            self._state["ltr"][layer] = h, r, z, n, n_pre, w_in.T, w_hn.T
        else:
            self._state["rtl"][layer] = h, r, z, n, n_pre, w_in.T, w_hn.T


class BackpropLayerNorm(BackpropModuleMixin, nn.LayerNorm):
    """
    NumPy forward pass for nn.LayerNorm, as used by the Transformer
    modules.
    """

    def attr_forward(self, x):
        """
        Normalizes x over its trailing len(normalized_shape) axes and,
        if the layer is affine, applies the learned scale and shift.
        :param x: The array to normalize
        :return: The normalized (and possibly scaled and shifted) array.
            self._state always holds "mean" and "x"; for an affine
            layer it also holds "output" and "gamma_term" (the scaled
            value before the shift)
        """
        n_norm_dims = len(self.normalized_shape)
        reduce_axes = tuple(-d for d in range(1, n_norm_dims + 1))

        mean = x.mean(axis=reduce_axes, keepdims=True)
        centered = x - mean
        std = np.sqrt(x.var(axis=reduce_axes, keepdims=True) + self.eps)
        normalized = centered / std

        self._state = {"mean": mean, "x": x}
        if self.elementwise_affine:
            scaled = normalized * self.weight.detach().numpy()
            shifted = scaled + self.bias.detach().numpy()
            self._state["output"] = shifted
            self._state["gamma_term"] = scaled
            return shifted
        return normalized

In [None]:
from typing import List, Tuple, Union

import numpy as np
import scipy.special as sp
import torch
from torch import nn
from transformers import BertConfig
from transformers.models.bert import modeling_bert as bert

# from interpret_nlp.modules import backprop_module as bp

# Shorthands for different array sizes. These are plain aliases used in
# annotations only: the shape in each trailing comment is the intended
# layout and is not checked at runtime.
HiddenArray = np.ndarray  # (batch_size, seq_len, hidden_size)
AttentionArray = np.ndarray  # (batch_size, num_heads, seq_len, _)
IndexTensor = torch.LongTensor  # (batch_size, seq_len)
EmbeddingTensor = torch.FloatTensor  # (batch_size, seq_len, hidden_size)

# Layer types accepted by BackpropBertMixin.convert_to_attr
NormalLayer = Union[nn.Linear, nn.LayerNorm]
# Annotation used by BackpropBertMixin.convert_bert_to_attr. Callers also
# pass other BERT sub-modules (e.g. a bert.BertAttention), so this alias
# is narrower than actual use.
AttnSubLayer = Union[bert.BertSelfAttention, bert.BertSelfOutput]

_erf_approx = lambda x: np.tanh(np.sqrt(2. / np.pi) * (x + 0.044715 * x ** 3))
activations = dict(relu=lambda x: np.maximum(x, 0.),
                   gelu=lambda x: x * .5 * (1. + sp.erf(x / np.sqrt(2.))),
                   swish=lambda x: x * sp.expit(x),
                   gelu_new=lambda x: x * .5 * _erf_approx(x),
                   mish=None)


def hidden_to_attention(h: HiddenArray, num_heads: int) -> AttentionArray:
    """
    Splits the last axis of a hidden-state array into attention heads
    and moves the head axis in front of the sequence axis.
    :param h: Hidden states (batch_size, seq_len, hidden_size)
    :param num_heads: The number of attention heads
    :return: (batch_size, num_heads, seq_len, hidden_size / num_heads)
    """
    split_shape = h.shape[:-1] + (num_heads, -1)
    per_head = h.reshape(split_shape)
    return per_head.transpose(0, 2, 1, 3)


def attention_to_hidden(a: AttentionArray) -> HiddenArray:
    """
    Inverse of hidden_to_attention: moves the head axis behind the
    sequence axis and merges the heads back into one hidden axis.
    :param a: Per-head array (batch_size, num_heads, seq_len, head_size)
    :return: (batch_size, seq_len, num_heads * head_size)
    """
    per_token = a.transpose(0, 2, 1, 3)
    merged_shape = per_token.shape[:-2] + (-1,)
    return per_token.reshape(merged_shape)


class BackpropBertMixin(BackpropModuleMixin):
    """
    Interface for BERT modules with custom backprop. This mixin
    introduces functions that convert a normal PyTorch module to the
    corresponding custom backprop module.
    """

    # Replacement classes used by convert_to_attr
    _layer_types = {nn.Linear: BackpropLinear,
                    nn.LayerNorm: BackpropLayerNorm}
    # Replacement classes used by convert_bert_to_attr, keyed by the
    # exact type of the module being replaced. The None placeholders are
    # overridden by the classes that need the corresponding conversion.
    _bert_layer_types = {bert.BertSelfAttention: None,
                         bert.BertSelfOutput: None,
                         bert.BertAttention: None,
                         bert.BertIntermediate: None,
                         bert.BertOutput: None}

    def convert_to_attr(self, layer: NormalLayer) -> BackpropModuleMixin:
        """
        Converts nn.Linear or nn.LayerNorm to custom backprop layers,
        using the classes registered in the _layer_types dict defined
        above. Child classes may override that dict to substitute their
        own layer classes.
        :param layer: An nn.Linear or nn.LayerNorm module
        :return: The corresponding module with custom backprop, carrying
            a copy of the parameters of layer
        :raises TypeError: If layer is neither an nn.Linear nor an
            nn.LayerNorm
        """
        if isinstance(layer, nn.Linear):
            linear_class = self._layer_types[nn.Linear]
            # Mirror the presence of a bias: a replacement built with a
            # bias cannot load the state dict of a bias-free layer
            new_layer = linear_class(layer.in_features, layer.out_features,
                                     bias=layer.bias is not None)
        elif isinstance(layer, nn.LayerNorm):
            ln_class = self._layer_types[nn.LayerNorm]
            new_layer = ln_class(layer.normalized_shape, eps=layer.eps,
                                 elementwise_affine=layer.elementwise_affine)
        else:
            raise TypeError("Cannot convert layer of type " + str(type(layer)))

        new_layer.load_state_dict(layer.state_dict())
        return new_layer

    def convert_bert_to_attr(self, layer: AttnSubLayer, config: BertConfig):
        """
        Converts a BERT sub-module to its custom backprop counterpart,
        looked up in _bert_layer_types by the exact type of layer.
        :param layer: The BERT sub-module to convert
        :param config: The BERT configuration the replacement is built
            from
        :return: The replacement module, carrying a copy of the
            parameters of layer
        """
        new_layer = self._bert_layer_types[type(layer)](config)
        new_layer.load_state_dict(layer.state_dict())
        return new_layer

    def hidden_to_attention(self, h: HiddenArray) -> AttentionArray:
        """
        Splits h into self.num_attention_heads attention heads.
        """
        return hidden_to_attention(h, self.num_attention_heads)

    @staticmethod
    def attention_to_hidden(a: AttentionArray) -> HiddenArray:
        """
        Merges the attention heads of a back into one hidden axis.
        """
        return attention_to_hidden(a)


class BackpropBertEmbeddings(BackpropBertMixin, bert.BertEmbeddings):
    """
    The BERT embedding layer: sums word, positional and token type
    embeddings and layer-normalizes the result.
    """

    # attr_forward performs embedding lookups, so it needs the original
    # index tensors rather than NumPy arrays
    _convert_attr_input_to_numpy = False

    def __init__(self, config: BertConfig):
        super().__init__(config)
        self.LayerNorm = self.convert_to_attr(self.LayerNorm)

    def attr(self):
        """
        Puts this module and its layer norm in attribution mode.
        """
        super().attr()
        self.LayerNorm.attr()

    def attr_forward(self, input_ids: IndexTensor = None,
                     inputs_embeds: EmbeddingTensor = None,
                     token_type_ids: IndexTensor = None,
                     position_ids: IndexTensor = None) -> HiddenArray:
        """
        Looks up the three embeddings, sums them and applies the layer
        norm.
        :param input_ids: Indices for an input sequence
        :param inputs_embeds: Embeddings for an input sequence. Exactly
            one of this and input_ids must be given
        :param token_type_ids: Indices into the token type embedding
            table; all zeros when omitted
        :param position_ids: Indices into the position embedding table;
            the first seq_len entries of self.position_ids when omitted
        :return: The input to the first BERT layer. The three summands
            are kept in self._state as (word, position, token type)
        """
        assert (input_ids is None) != (inputs_embeds is None)

        if input_ids is None:
            input_shape = inputs_embeds.shape[:-1]
            word_embeds = inputs_embeds.detach().numpy()
        else:
            input_shape = input_ids.shape
            word_embeds = self.word_embeddings(input_ids).detach().numpy()

        seq_length = input_shape[1]
        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]
        position_embeds = \
            self.position_embeddings(position_ids).detach().numpy()

        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long,
                                         device=self.position_ids.device)
        token_type_embeds = \
            self.token_type_embeddings(token_type_ids).detach().numpy()

        self._state = (word_embeds, position_embeds, token_type_embeds)
        summed = word_embeds + position_embeds + token_type_embeds
        return self.LayerNorm(summed)


class BackpropBertSelfAttention(BackpropBertMixin, bert.BertSelfAttention):
    """
    The scaled dot-product attention part of a BERT attention layer.
    Together with BackpropBertSelfOutput it forms a complete attention
    layer.
    """

    def __init__(self, config: BertConfig):
        super().__init__(config)
        self.query = self.convert_to_attr(self.query)
        self.key = self.convert_to_attr(self.key)
        self.value = self.convert_to_attr(self.value)

    def attr(self):
        """
        Puts this module and its three projections in attribution mode.
        """
        super().attr()
        for projection in (self.query, self.key, self.value):
            projection.attr()

    def attr_forward(self, hidden_states: HiddenArray,
                     attention_mask: AttentionArray = None,
                     head_mask: AttentionArray = None,
                     encoder_hidden_states: HiddenArray = None,
                     encoder_attention_mask: AttentionArray = None) -> \
            Tuple[HiddenArray, AttentionArray]:
        """
        Computes scaled dot-product attention in NumPy.
        :param hidden_states: The input to the attention layer; always
            the source of the queries
        :param attention_mask: Added to the attention scores before the
            softmax, if given
        :param head_mask: Multiplied into the attention probabilities,
            if given
        :param encoder_hidden_states: If given, keys and values are
            computed from these states instead of hidden_states, and
            encoder_attention_mask replaces attention_mask
        :param encoder_attention_mask: The mask used together with
            encoder_hidden_states
        :return: The attention output, with the heads merged back into
            the hidden axis, and the attention probabilities
        """
        if encoder_hidden_states is None:
            kv_source = hidden_states
        else:
            kv_source = encoder_hidden_states
            attention_mask = encoder_attention_mask

        queries = self.hidden_to_attention(self.query(hidden_states))
        keys = self.hidden_to_attention(self.key(kv_source))
        values = self.hidden_to_attention(self.value(kv_source))

        # The updates below are in place on purpose, so the score array
        # keeps its dtype
        scores = queries @ keys.transpose(0, 1, 3, 2)
        scores /= np.sqrt(self.attention_head_size)
        if attention_mask is not None:
            scores += attention_mask

        probs = sp.softmax(scores, axis=-1)
        if head_mask is not None:
            probs *= head_mask

        per_head_context = probs @ values
        self._state = {"context_layer": per_head_context,
                       "attention_probs": probs,
                       "value_layer": values}

        return self.attention_to_hidden(per_head_context), probs


class BackpropBertSelfOutput(BackpropBertMixin, bert.BertSelfOutput):
    """
    The output projection and add-and-norm part of a self-attention
    layer. This layer is used with BackpropBertSelfAttention.
    """

    def __init__(self, config: BertConfig):
        super().__init__(config)
        self.dense = self.convert_to_attr(self.dense)
        self.LayerNorm = self.convert_to_attr(self.LayerNorm)

    def attr(self):
        """
        Puts this module and its sub-layers in attribution mode.
        """
        super().attr()
        for sublayer in (self.dense, self.LayerNorm):
            sublayer.attr()

    def attr_forward(self, hidden_states: HiddenArray,
                     input_tensor: HiddenArray) -> HiddenArray:
        """
        Projects hidden_states, adds the residual and normalizes.
        :param hidden_states: The output of the attention computation
        :param input_tensor: The residual added before the layer norm
        :return: LayerNorm(dense(hidden_states) + input_tensor); both
            summands are kept in self._state
        """
        projected = self.dense(hidden_states)
        self._state = {"dense_output": projected,
                       "input_tensor": input_tensor}
        residual_sum = projected + input_tensor
        return self.LayerNorm(residual_sum)


class BackpropBertAttention(BackpropBertMixin, bert.BertAttention):
    """
    A complete self-attention layer: BackpropBertSelfAttention followed
    by BackpropBertSelfOutput.
    """

    _bert_layer_types = {bert.BertSelfAttention: BackpropBertSelfAttention,
                         bert.BertSelfOutput: BackpropBertSelfOutput}

    def __init__(self, config: BertConfig):
        super().__init__(config)
        self.self = self.convert_bert_to_attr(self.self, config)
        self.output = self.convert_bert_to_attr(self.output, config)

    def attr(self):
        """
        Puts this module and both of its halves in attribution mode.
        """
        super().attr()
        for half in (self.self, self.output):
            half.attr()

    def attr_forward(self, hidden_states: HiddenArray,
                     attention_mask: AttentionArray = None,
                     head_mask: AttentionArray = None,
                     encoder_hidden_states: HiddenArray = None,
                     encoder_attention_mask: AttentionArray = None) -> \
            Tuple[HiddenArray, AttentionArray]:
        """
        Runs the attention computation and feeds its result, together
        with the layer input as the residual, to the output sub-layer.
        All mask and encoder arguments are passed through to the
        attention computation unchanged.
        :param hidden_states: The attention layer input (batch_size,
            seq_len, hidden_size)
        :param attention_mask: The attention mask
        :param head_mask: An optional mask for heads
        :param encoder_hidden_states: Optional encoder hidden states
        :param encoder_attention_mask: The mask belonging to
            encoder_hidden_states
        :return: The attention layer output and the attention
            probabilities
        """
        self_outputs = self.self(
            hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask)
        layer_output = self.output(self_outputs[0], hidden_states)
        return layer_output, self_outputs[1]


class BackpropBertIntermediate(BackpropBertMixin, bert.BertIntermediate):
    """
    The first feed-forward layer after self-attention: a linear layer
    followed by the activation named by config.hidden_act.
    """

    def __init__(self, config: BertConfig):
        super().__init__(config)
        self.dense = self.convert_to_attr(self.dense)
        # Looked up by activation name; presumably a NumPy
        # implementation -- confirm against the activations table.
        act_name = config.hidden_act
        self.intermediate_act_fn_numpy = activations[act_name]

    def attr(self):
        """
        Calls attr() on this module and then on the linear layer.
        """
        super().attr()
        self.dense.attr()

    def attr_forward(self, hidden_states: HiddenArray) -> HiddenArray:
        """
        :param hidden_states: The input to the linear layer
        :return: The activation applied to the linear layer output
        """
        pre_activation = self.dense(hidden_states)
        return self.intermediate_act_fn_numpy(pre_activation)


class BackpropBertOutput(BackpropBertMixin, bert.BertOutput):
    """
    The closing linear layer of a Transformer block, followed by the
    residual addition and LayerNorm.
    """

    def __init__(self, config: BertConfig):
        super().__init__(config)
        self.dense = self.convert_to_attr(self.dense)
        self.LayerNorm = self.convert_to_attr(self.LayerNorm)

    def attr(self):
        """
        Calls attr() on this module and then on both sub-modules.
        """
        super().attr()
        for sub_module in (self.dense, self.LayerNorm):
            sub_module.attr()

    def attr_forward(self, hidden_states: HiddenArray,
                     input_tensor: HiddenArray) -> HiddenArray:
        """
        :param hidden_states: The input to the linear layer
        :param input_tensor: The residual that is added to the linear
            layer output before normalization
        :return: The normalized sum of the two
        """
        projected = self.dense(hidden_states)
        # Both summands are stored on self._state under these keys.
        self._state = {"dense_output": projected,
                       "input_tensor": input_tensor}
        return self.LayerNorm(projected + input_tensor)


class BackpropBertLayer(BackpropBertMixin, bert.BertLayer):
    """
    One complete BERT block: self-attention, optional cross-attention,
    and the feed-forward sub-modules.
    """
    _bert_layer_types = {
        bert.BertAttention: BackpropBertAttention,
        bert.BertIntermediate: BackpropBertIntermediate,
        bert.BertOutput: BackpropBertOutput,
    }

    def __init__(self, config: BertConfig):
        super().__init__(config)
        self.attention = self.convert_bert_to_attr(self.attention, config)
        if self.add_cross_attention:
            # The cross-attention module only exists when configured.
            self.crossattention = self.convert_bert_to_attr(
                self.crossattention, config)
        self.intermediate = self.convert_bert_to_attr(self.intermediate,
                                                      config)
        self.output = self.convert_bert_to_attr(self.output, config)

    def attr(self):
        """
        Calls attr() on this module and then on every sub-module.
        """
        super().attr()
        sub_modules = [self.attention]
        if self.add_cross_attention:
            sub_modules.append(self.crossattention)
        sub_modules += [self.intermediate, self.output]
        for sub_module in sub_modules:
            sub_module.attr()

    def attr_forward(self, hidden_states: HiddenArray,
                     attention_mask: AttentionArray = None,
                     head_mask: AttentionArray = None,
                     encoder_hidden_states: HiddenArray = None,
                     encoder_attention_mask: AttentionArray = None) -> \
            Tuple[HiddenArray, AttentionArray]:
        """
        Runs one encoder or decoder block end to end.
        :param hidden_states: The input to the block
        :param attention_mask: The attention mask
        :param head_mask: The head mask
        :param encoder_hidden_states: Hidden states from the encoder;
            cross-attention runs only when this is given and the block
            is a decoder
        :param encoder_attention_mask: The encoder attention mask, used
            by cross-attention only
        :return: The block output, along with the attention scores of
            the last attention module that ran
        """
        # Every sub-module must already be in attr mode.
        required = [self.attention, self.intermediate, self.output]
        if self.add_cross_attention:
            required.append(self.crossattention)
        for sub_module in required:
            assert sub_module.attr_mode

        attn_output, attn_probs = self.attention(
            hidden_states, attention_mask=attention_mask,
            head_mask=head_mask)

        # Record whether cross-attention ran (read back by subclasses).
        use_cross = bool(self.is_decoder
                         and encoder_hidden_states is not None)
        self._state = {"crossattention_used": use_cross}
        if use_cross:
            assert hasattr(self, "crossattention")
            assert self.crossattention.attr_mode
            attn_output, attn_probs = self.crossattention(
                attn_output, attention_mask, head_mask,
                encoder_hidden_states, encoder_attention_mask)

        # TODO: Make apply_chunking_to_forward compatible with NumPy
        block_output = bert.apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward,
            self.seq_len_dim, attn_output)

        return block_output, attn_probs


class BackpropBertEncoder(BackpropBertMixin, bert.BertEncoder):
    """
    A BERT encoder, consisting of multiple encoder blocks.
    """

    _bert_layer_types = {bert.BertLayer: BackpropBertLayer}

    def __init__(self, config: BertConfig):
        super(BackpropBertEncoder, self).__init__(config)
        layers = [self.convert_bert_to_attr(e, config) for e in self.layer]
        self.layer = nn.ModuleList(layers)

    def attr(self):
        """
        Calls attr() on this module and then on every block.
        """
        super(BackpropBertEncoder, self).attr()
        for e in self.layer:
            e.attr()

    def attr_forward(self, hidden_states: HiddenArray,
                     attention_mask: AttentionArray = None,
                     head_mask: AttentionArray = None,
                     encoder_hidden_states: HiddenArray = None,
                     encoder_attention_mask: AttentionArray = None) -> \
            Tuple[HiddenArray, List[HiddenArray], List[AttentionArray]]:
        """
        A full BERT encoder, consisting of multiple encoder blocks.
        :param hidden_states: The combined word, position, and token
            type embeddings
        :param attention_mask: The attention mask
        :param head_mask: A per-block sequence of head masks (indexed
            by block number), or None to use no head mask
        :param encoder_hidden_states: Passed unchanged to every block
        :param encoder_attention_mask: Passed unchanged to every block
        :return: The output of the last layer, along with the input to
            each layer (the final output is not repeated in this list)
            and the attention scores of all layers
        """
        all_hidden_states = []
        all_attentions = []

        for i, e in enumerate(self.layer):
            all_hidden_states.append(hidden_states)

            # The default head_mask=None must not be indexed.
            layer_head_mask = head_mask[i] if head_mask is not None else None

            # TODO: Add gradient checkpointing
            # encoder_attention_mask is forwarded as well; it used to be
            # accepted by this method but silently dropped.
            layer_outputs = e(hidden_states, attention_mask=attention_mask,
                              head_mask=layer_head_mask,
                              encoder_hidden_states=encoder_hidden_states,
                              encoder_attention_mask=encoder_attention_mask)

            hidden_states = layer_outputs[0]
            all_attentions.append(layer_outputs[1])

        return hidden_states, all_hidden_states, all_attentions


class BackpropBertPooler(BackpropBertMixin, bert.BertPooler):
    """
    Pools the BERT output: the hidden state at position 0 (the CLS
    token) goes through a linear layer and a tanh.
    """

    def __init__(self, config: BertConfig):
        super().__init__(config)
        self.dense = self.convert_to_attr(self.dense)

    def attr(self):
        """
        Calls attr() on this module and then on the linear layer.
        """
        super().attr()
        self.dense.attr()

    def attr_forward(self, hidden_states: HiddenArray) -> np.ndarray:
        """
        :param hidden_states: The encoder output; only index 0 along
            the second axis is used
        :return: The pooled output
        """
        first_token = hidden_states[:, 0]
        return np.tanh(self.dense(first_token))


class BackpropBertModel(BackpropBertMixin, bert.BertModel):
    """
    A full BERT model. This is a stack of Transformer encoders that
    takes an input sequence of the form
        [CLS] sequence1 [SEP] sequence2
    and produces an output sequence of the same form. It is pre-trained
    on BERT's masked language modeling objective.
    """
    _bert_layer_types = {bert.BertEmbeddings: BackpropBertEmbeddings,
                         bert.BertEncoder: BackpropBertEncoder,
                         bert.BertPooler: BackpropBertPooler}

    _convert_attr_input_to_numpy = False

    def __init__(self, config: BertConfig):
        super(BackpropBertModel, self).__init__(config)
        self.embeddings = self.convert_bert_to_attr(self.embeddings, config)
        self.encoder = self.convert_bert_to_attr(self.encoder, config)
        self.pooler = self.convert_bert_to_attr(self.pooler, config)

    def attr(self):
        """
        Calls attr() on this module and then on the embeddings, the
        encoder and the pooler.
        """
        super(BackpropBertModel, self).attr()
        self.embeddings.attr()
        self.encoder.attr()
        self.pooler.attr()

    def attr_forward(self, input_ids=None, attention_mask=None,
                     token_type_ids=None, position_ids=None, head_mask=None,
                     inputs_embeds=None, encoder_hidden_states=None,
                     encoder_attention_mask=None):
        """
        The complete BERT forward pass.
        :param input_ids: An input sequence, represented as an index
            tensor of shape (batch_size, seq_len)
        :param attention_mask: An attention mask that masks out [PAD]
            symbols and symbols without a prediction. Defaults to all
            ones
        :param token_type_ids: Segment indices telling sequence1 apart
            from sequence2. Defaults to all zeros
        :param position_ids: The positional encoding
        :param head_mask: A mask for disabling individual attention
            heads; expanded to one entry per layer by get_head_mask
        :param inputs_embeds: Embedding vectors for the input. This
            cannot be specified if input_ids is specified, and vice
            versa
        :param encoder_hidden_states: Hidden states from a previous
            computation, which will be reused
        :param encoder_attention_mask: The attention mask from a
            previous computation, which will be reused
        :return: The sequence output, the pooled output, the list of
            inputs to each encoder block, and the list of attention
            scores of each encoder block
        :raises ValueError: If both or neither of input_ids and
            inputs_embeds are given
        """
        # Local import: only the decoder branch below needs it.
        import warnings

        # Get input embedding shape
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and "
                             "inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.shape
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.shape[:-1]
        else:
            raise ValueError("You have to specify either input_ids or "
                             "inputs_embeds")

        if attention_mask is None:
            attention_mask = torch.ones(input_shape)
        if token_type_ids is None:
            token_type_ids = np.zeros(input_shape, dtype="int64")

        # Let the Hugging Face helper expand the padding mask into the
        # form the attention layers consume, then hand it over as NumPy.
        extended_attention_mask = \
            self.get_extended_attention_mask(attention_mask, input_shape,
                                             torch.device("cpu"))
        extended_attention_mask = extended_attention_mask.detach().numpy()

        if self.config.is_decoder and encoder_hidden_states is not None:
            # This used to `raise RuntimeWarning(...)`, which made the
            # mask preparation below unreachable. Emit it as a warning
            # so the (still unverified) decoder path can run.
            warnings.warn("I didn't implement this carefully",
                          RuntimeWarning)
            encoder_hidden_shape = encoder_hidden_states.shape[:-1]
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape)
            encoder_extended_attn_mask = \
                self.invert_attention_mask(encoder_attention_mask)
            encoder_extended_attn_mask = \
                encoder_extended_attn_mask.detach().numpy()
        else:
            encoder_extended_attn_mask = None

        head_mask = self.get_head_mask(head_mask,
                                       self.config.num_hidden_layers)

        # Begin forward pass
        embedding_output = self.embeddings(input_ids=input_ids,
                                           position_ids=position_ids,
                                           token_type_ids=token_type_ids,
                                           inputs_embeds=inputs_embeds)

        encoder_outputs = \
            self.encoder(embedding_output,
                         attention_mask=extended_attention_mask,
                         head_mask=head_mask,
                         encoder_hidden_states=encoder_hidden_states,
                         encoder_attention_mask=encoder_extended_attn_mask)

        sequence_output = encoder_outputs[0]
        # The output shape is stored on self._state for later use.
        self._state = {"output_shape": sequence_output.shape}
        pooled_output = self.pooler(sequence_output)
        return (sequence_output, pooled_output) + encoder_outputs[1:]


BFSC = bert.BertForSequenceClassification


class BackpropBertForSequenceClassification(BackpropBertMixin, BFSC):
    """
    A BERT model topped with a linear classifier over the pooled
    output.
    """
    _bert_layer_types = {bert.BertModel: BackpropBertModel}

    _convert_attr_input_to_numpy = False

    def __init__(self, config: BertConfig):
        super().__init__(config)
        self.bert = self.convert_bert_to_attr(self.bert, config)
        self.classifier = self.convert_to_attr(self.classifier)

    def attr(self):
        """
        Calls attr() on this module and then on the BERT model and the
        classifier.
        """
        super().attr()
        for sub_module in (self.bert, self.classifier):
            sub_module.attr()

    def attr_forward(self, **kwargs):
        """
        :param kwargs: Keyword arguments passed unchanged to the BERT
            model
        :return: The classifier applied to the second item of the BERT
            model's output (the pooled output)
        """
        bert_outputs = self.bert(**kwargs)
        pooled_output = bert_outputs[1]
        return self.classifier(pooled_output)

In [None]:
import operator
from abc import ABC, abstractmethod
from functools import reduce

import numpy as np



class LRPLinear(BackpropLinear):
    """
    A Linear module with LRP.
    """

    def attr_backward(self, rel_y: np.ndarray,
                      eps: float = 0.001) -> np.ndarray:
        """
        :param rel_y: The relevance of the layer output
        :param eps: The LRP stabilizer
        :return: The result of lrp_linear for the stored input
        """
        layer_input = self._input[0]
        stored_wx = self._state["wx"]
        # lrp_linear receives the transposed weight matrix as NumPy.
        weight_t = self.weight.detach().numpy().T
        return lrp_linear(layer_input, stored_wx, rel_y, weight_t, eps=eps)


class LRPRNNMixin(ABC):
    """
    Shared LRP backward pass for multi-layer, optionally bidirectional
    RNNs. Subclasses supply the per-layer, per-direction step.
    """

    def attr_backward(self, rel_y: np.ndarray,
                      eps: float = 0.001) -> np.ndarray:
        """
        Walks the layers from last to first, delegating each layer and
        direction to _layer_backward.
        :param rel_y: The relevance of the RNN output, of shape
            (batch_size, seq_len, hidden_size); for a bidirectional RNN
            the last axis is split into a forward and a backward half
        :param eps: The LRP stabilizer
        :return: The relevance of the stored input
        """
        two_way = self.bidirectional
        if two_way:
            fwd_rel, bwd_rel = np.split(rel_y, 2, axis=-1)
            # The backward direction is processed in reversed time order.
            bwd_rel = np.flip(bwd_rel, 1)
        else:
            fwd_rel, bwd_rel = rel_y, None

        for depth in range(self.num_layers - 1, -1, -1):
            layer_rel = self._layer_backward(fwd_rel, depth, 0, eps=eps)
            if two_way:
                mirrored = self._layer_backward(bwd_rel, depth, 1, eps=eps)
                # Flip back to forward time order before accumulating.
                layer_rel += np.flip(mirrored, 1)

            if two_way and depth > 0:
                # The layer below produced both directions side by side.
                fwd_rel = layer_rel[:, :, :self.hidden_size]
                bwd_rel = np.flip(layer_rel[:, :, self.hidden_size:], 1)
            else:
                fwd_rel = layer_rel

        return fwd_rel

    @abstractmethod
    def _layer_backward(self, rel_y: np.ndarray, layer: int, direction: int,
                        eps: float = 0.001) -> np.ndarray:
        """
        The LRP backward pass for one layer and one direction.
        :param rel_y: The relevance flowing to this layer
        :param layer: The index of the layer
        :param direction: 0 for the first call of a layer, 1 for the
            extra call made when the RNN is bidirectional
        :param eps: The LRP stabilizer
        :return: The relevance of the layer inputs
        """
        raise NotImplementedError("_layer_backward not implemented")


class LRPLSTM(LRPRNNMixin, BackpropLSTM):
    """
    An LSTM module with LRP.
    """

    def _layer_backward(self, rel_y: np.ndarray, layer: int, direction: int,
                        eps: float = 0.001) -> np.ndarray:
        """
        Performs a backward pass using numpy operations for one layer.
        Relevance is propagated from the last time step to the first.
        :param rel_y: The relevance flowing to this layer
        :param layer: The layer to perform the backward pass for
        :param direction: The direction to perform the backward pass
            for: 0 uses the "ltr" state, any other value uses the "rtl"
            state together with the time-reversed input
        :param eps: The LRP stabilizer, passed to every lrp_linear call
        :return: The relevance of the layer inputs
        """
        # Presumably h/c are the hidden and cell states, i/f the input
        # and forget gates, g/g_pre the candidate and its
        # pre-activation, and w_ig/w_hg the candidate's input and
        # hidden weights -- TODO confirm against BackpropLSTM.
        if direction == 0:
            x = self._input[layer]
            h, c, i, f, g, g_pre, w_ig, w_hg = self._state["ltr"][layer]
        else:
            x = np.flip(self._input[layer], 1)
            h, c, i, f, g, g_pre, w_ig, w_hg = self._state["rtl"][layer]

        batch_size, seq_len, _ = x.shape

        # Initialize
        # rel_h and rel_c have one extra slot along the time axis: slot
        # t + 1 belongs to step t, slot t to the step before it.
        rel_h = np.zeros((batch_size, seq_len + 1, self.hidden_size))
        rel_c = np.zeros((batch_size, seq_len + 1, self.hidden_size))
        rel_g = np.zeros(g.shape)
        rel_x = np.zeros(x.shape)

        # Backward pass
        rel_h[:, 1:] = rel_y
        for t in reversed(range(seq_len)):
            # The hidden-state relevance of step t is added to the
            # cell-state relevance of step t.
            rel_c[:, t + 1] += rel_h[:, t + 1]
            # Split the cell relevance between the previous cell state
            # (through f) and the candidate (through i).
            # NOTE(review): at t == 0, c[:, t - 1] is c[:, -1] (negative
            # indexing wraps to the last step). The resulting
            # rel_c[:, 0] is never read afterwards, so the returned
            # rel_x is unaffected -- confirm this is intended.
            rel_c[:, t] = lrp_linear(f[:, t] * c[:, t - 1], c[:, t],
                                     rel_c[:, t + 1], eps=eps)
            rel_g[:, t] = lrp_linear(i[:, t] * g[:, t], c[:, t],
                                     rel_c[:, t + 1], eps=eps)
            # Split the candidate relevance between the input at step t
            # and the previous hidden state.
            rel_x[:, t] = lrp_linear(x[:, t], g_pre[:, t], rel_g[:, t],
                                     w=w_ig, eps=eps)

            # A zero hidden state is used before the first step.
            h_prev = np.zeros((batch_size, self.hidden_size)) if t == 0 \
                else h[:, t - 1]
            rel_h[:, t] += lrp_linear(h_prev, g_pre[:, t], rel_g[:, t], w=w_hg,
                                      eps=eps)

        return rel_x


class LRPGRU(LRPRNNMixin, BackpropGRU):
    """
    A GRU module with LRP.
    """

    def _layer_backward(self, rel_y: np.ndarray, layer: int, direction: int,
                        eps: float = 0.001) -> np.ndarray:
        """
        Performs a backward pass using numpy operations for one layer.
        Relevance is propagated from the last time step to the first.
        :param rel_y: The relevance flowing to this layer
        :param layer: The layer to perform the backward pass for
        :param direction: The direction to perform the backward pass
            for: 0 uses the "ltr" state, any other value uses the "rtl"
            state together with the time-reversed input
        :param eps: The LRP stabilizer, passed to every lrp_linear call
        :return: The relevance of the layer inputs
        """
        # Presumably h is the hidden state, r/z the reset and update
        # gates, n/n_pre the candidate and its pre-activation, and
        # w_in/w_hn the candidate's input and hidden weights -- TODO
        # confirm against BackpropGRU.
        if direction == 0:
            x = self._input[layer]
            h, r, z, n, n_pre, w_in, w_hn = self._state["ltr"][layer]
        else:
            x = np.flip(self._input[layer], 1)
            h, r, z, n, n_pre, w_in, w_hn = self._state["rtl"][layer]

        batch_size, seq_len, _ = x.shape

        # Initialize
        # rel_h has one extra slot along the time axis: slot t + 1
        # belongs to step t, slot t to the step before it.
        rel_h = np.zeros((batch_size, seq_len + 1, self.hidden_size))
        rel_n = np.zeros(n.shape)
        rel_x = np.zeros(x.shape)

        # Backward pass
        rel_h[:, 1:] = rel_y
        for t in reversed(range(seq_len)):
            # NOTE(review): this plain assignment replaces, for t >= 1,
            # the value copied into rel_h[:, t] from rel_y above (the
            # LSTM counterpart accumulates instead). It is harmless
            # when only the last step of rel_y is non-zero -- confirm
            # whether "+=" was intended.
            # NOTE(review): if BackpropGRU follows the PyTorch update
            # h_t = (1 - z) * n + z * h_prev, the first argument would
            # be expected to be z[:, t] * h_prev rather than
            # z[:, t] * n[:, t] -- confirm against the stored state.
            rel_h[:, t] = lrp_linear(z[:, t] * n[:, t], h[:, t],
                                     rel_h[:, t + 1], eps=eps)
            rel_n[:, t] = lrp_linear((1 - z[:, t]) * n[:, t], h[:, t],
                                     rel_h[:, t + 1], eps=eps)
            # Split the candidate relevance between the input at step t
            # and the previous hidden state.
            rel_x[:, t] = lrp_linear(x[:, t], n_pre[:, t], rel_n[:, t],
                                     w=w_in, eps=eps)

            # A zero hidden state is used before the first step.
            h_prev = np.zeros((batch_size, self.hidden_size)) if t == 0 \
                else h[:, t - 1]
            # The hidden weights are scaled by the reset gate r.
            rel_h[:, t] += lrp_linear(h_prev, n_pre[:, t], rel_n[:, t],
                                      w=r[:, t] * w_hn, eps=eps)

        return rel_x


class LRPLayerNorm(BackpropLayerNorm):
    """
    A LayerNorm module with LRP.
    """

    def attr_backward(self, rel_y: np.ndarray, eps: float = 0.001) -> \
            np.ndarray:
        """
        Propagates relevance from the LayerNorm output back to its
        input, using the values stored in self._state.
        :param rel_y: The relevance of the LayerNorm output
        :param eps: The LRP stabilizer
        :return: The relevance of the LayerNorm input
        """
        cache = self._state

        if self.elementwise_affine:
            # Go back through the affine step first, keeping only the
            # share of the stored "gamma_term".
            rel_y = lrp_linear(cache["gamma_term"], cache["output"], rel_y,
                               eps=eps)

        x, mean = cache["x"], cache["mean"]
        centered = x - mean
        # Split the relevance of (x - mean) between x and the mean.
        rel_x = lrp_linear(x, centered, rel_y, eps=eps)
        rel_mean = lrp_linear(-mean, centered, rel_y, eps=eps)

        # Hand the mean's share back to x, with n * mean as the total,
        # where n is the number of normalized elements.
        size = reduce(operator.mul, self.normalized_shape, 1)
        rel_x += lrp_linear(x, size * mean, rel_mean, eps=eps)

        return rel_x

In [None]:
from typing import Tuple

import numpy as np
from torch import nn
from transformers.models.bert import modeling_bert as bert




class LRPBertMixin(BackpropBertMixin):
    """
    Mixin for BERT modules with LRP: pairs the plain torch layer types
    with their LRP classes.
    """
    _layer_types = {
        nn.Linear: LRPLinear,
        nn.LayerNorm: LRPLayerNorm,
    }


class LRPBertEmbeddings(LRPBertMixin, BackpropBertEmbeddings):
    """
    BertEmbeddings with LRP.
    """

    def attr_backward(self, rel_y: HiddenArray, eps: float = 0.001) -> \
            Tuple[HiddenArray, HiddenArray, HiddenArray]:
        """
        Goes back through the LayerNorm, then divides the resulting
        relevance among the three stored embeddings according to their
        share of the sum.
        :param rel_y: The relevance of the embedding layer output
        :param eps: The LRP stabilizer
        :return: The relevance of the input, position, and token type
            embeddings, in that order
        """
        rel_sum = self.LayerNorm.attr_backward(rel_y, eps=eps)

        # assumes self._state holds the three embedding arrays in this
        # order -- TODO confirm against BackpropBertEmbeddings
        inp_embeds, pos_embeds, tok_type_embeds = self._state
        total = inp_embeds + pos_embeds + tok_type_embeds
        return tuple(lrp_linear(part, total, rel_sum, eps=eps)
                     for part in (inp_embeds, pos_embeds, tok_type_embeds))


class LRPBertSelfAttention(LRPBertMixin, BackpropBertSelfAttention):
    """
    BertSelfAttention with LRP.
    """

    def attr_backward(self, rel_y: HiddenArray, eps: float = 0.001) -> \
            HiddenArray:
        """
        All relevance gets propagated to the value layer.
        :param rel_y: The relevance of the self-attention output
        :param eps: The LRP stabilizer, used for both the matrix
            product and the value projection
        :return: The relevance returned by the value projection's
            backward pass
        """
        rel_value_layer = lrp_matmul(self._state["value_layer"],
                                     self._state["attention_probs"],
                                     self._state["context_layer"],
                                     self.hidden_to_attention(rel_y),
                                     eps=eps)

        rel_value_layer = self.attention_to_hidden(rel_value_layer)
        # Pass eps on: it used to be dropped here, so a non-default
        # stabilizer was silently replaced by the default.
        return self.value.attr_backward(rel_value_layer, eps=eps)


class LRPBertSelfOutput(LRPBertMixin, BackpropBertSelfOutput):
    """
    BertSelfOutput with LRP.
    """

    def attr_backward(self, rel_y: HiddenArray, eps: float = 0.001) -> \
            Tuple[HiddenArray, HiddenArray]:
        """
        Goes back through the LayerNorm, splits the relevance between
        the two stored summands, and sends the linear layer's share
        through the linear layer.
        :param rel_y: The relevance of this module's output
        :param eps: The LRP stabilizer, used for every step
        :return: The relevance of the linear layer's input, followed by
            the relevance of the stored "input_tensor"
        """
        input_tensor = self._state["input_tensor"]
        dense_output = self._state["dense_output"]
        pre_layer_norm = input_tensor + dense_output

        # eps is passed to the LayerNorm and linear layers as well;
        # both calls used to fall back to the default stabilizer.
        rel_pre_layer_norm = self.LayerNorm.attr_backward(rel_y, eps=eps)
        rel_input_tensor = lrp_linear(input_tensor, pre_layer_norm,
                                      rel_pre_layer_norm, eps=eps)
        rel_dense_output = lrp_linear(dense_output, pre_layer_norm,
                                      rel_pre_layer_norm, eps=eps)
        rel_hidden_states = self.dense.attr_backward(rel_dense_output,
                                                     eps=eps)

        return rel_hidden_states, rel_input_tensor


class LRPBertAttention(LRPBertMixin, BackpropBertAttention):
    """
    BertAttention with LRP.
    """
    _bert_layer_types = {
        bert.BertSelfAttention: LRPBertSelfAttention,
        bert.BertSelfOutput: LRPBertSelfOutput,
    }

    def attr_backward(self, rel_y: HiddenArray, eps: float = 0.001) -> \
            HiddenArray:
        """
        :param rel_y: The relevance of the attention layer output
        :param eps: The LRP stabilizer
        :return: The second item returned by the output sub-module's
            backward pass, plus the relevance returned by the
            self-attention backward pass
        """
        rel_context, rel_residual = self.output.attr_backward(rel_y, eps=eps)
        rel_residual += self.self.attr_backward(rel_context, eps=eps)
        return rel_residual


class LRPBertIntermediate(LRPBertMixin, BackpropBertIntermediate):
    """
    BertIntermediate with LRP.
    """

    def attr_backward(self, rel_y: HiddenArray, eps: float = 0.001) -> \
            HiddenArray:
        """
        The relevance is handed to the linear layer as is; no separate
        step is taken for the activation function.
        :param rel_y: The relevance of this module's output
        :param eps: The LRP stabilizer
        :return: The relevance returned by the linear layer
        """
        linear_layer = self.dense
        return linear_layer.attr_backward(rel_y, eps=eps)


class LRPBertOutput(LRPBertMixin, BackpropBertOutput):
    """
    BertOutput with LRP.
    """

    def attr_backward(self, rel_y: HiddenArray, eps: float = 0.001) -> \
            Tuple[HiddenArray, HiddenArray]:
        """
        Goes back through the LayerNorm, splits the relevance between
        the two stored summands, and sends the linear layer's share
        through the linear layer.
        :param rel_y: The relevance of this module's output
        :param eps: The LRP stabilizer, used for every step
        :return: The relevance of the linear layer's input, followed by
            the relevance of the stored "input_tensor"
        """
        input_tensor = self._state["input_tensor"]
        dense_output = self._state["dense_output"]
        pre_layer_norm = input_tensor + dense_output

        # eps is passed to the LayerNorm and linear layers as well;
        # both calls used to fall back to the default stabilizer.
        rel_pre_layer_norm = self.LayerNorm.attr_backward(rel_y, eps=eps)
        rel_input_tensor = lrp_linear(input_tensor, pre_layer_norm,
                                      rel_pre_layer_norm, eps=eps)
        rel_dense_output = lrp_linear(dense_output, pre_layer_norm,
                                      rel_pre_layer_norm, eps=eps)
        rel_hidden_states = self.dense.attr_backward(rel_dense_output,
                                                     eps=eps)

        return rel_hidden_states, rel_input_tensor


class LRPBertLayer(BackpropBertLayer):
    """
    BertLayer with LRP.
    """
    _bert_layer_types = {
        bert.BertAttention: LRPBertAttention,
        bert.BertIntermediate: LRPBertIntermediate,
        bert.BertOutput: LRPBertOutput,
    }

    def attr_backward(self, rel_y: HiddenArray, eps: float = 0.001) -> \
            HiddenArray:
        """
        Goes back through the output and intermediate sub-modules, then
        through cross-attention if the stored state says it was used,
        and finally through self-attention.
        :param rel_y: The relevance of the block output
        :param eps: The LRP stabilizer
        :return: The relevance returned by the self-attention module
        """
        rel_ffn, rel_attn_out = self.output.attr_backward(rel_y, eps=eps)
        rel_attn_out += self.intermediate.attr_backward(rel_ffn, eps=eps)

        cross_used = self._state["crossattention_used"]
        if cross_used:
            rel_attn_out = self.crossattention.attr_backward(rel_attn_out,
                                                             eps=eps)
        return self.attention.attr_backward(rel_attn_out, eps=eps)


class LRPBertEncoder(BackpropBertEncoder):
    """
    BertEncoder with LRP.
    """
    _bert_layer_types = {bert.BertLayer: LRPBertLayer}

    def attr_backward(self, rel_y: HiddenArray, eps: float = 0.001) -> \
            HiddenArray:
        """
        Passes the relevance through the blocks from last to first.
        :param rel_y: The relevance of the encoder output
        :param eps: The LRP stabilizer
        :return: The relevance returned by the first block
        """
        rel = rel_y
        for idx in range(len(self.layer) - 1, -1, -1):
            rel = self.layer[idx].attr_backward(rel, eps=eps)
        return rel


class LRPBertPooler(LRPBertMixin, BackpropBertPooler):
    """
    BertPooler with LRP.
    """

    def attr_backward(self, rel_y: np.ndarray, eps: float = 0.001) -> \
            HiddenArray:
        """
        The relevance is handed to the linear layer as is; no separate
        step is taken for the tanh.
        :param rel_y: The relevance of the pooled output
        :param eps: The LRP stabilizer
        :return: The relevance returned by the linear layer
        """
        linear_layer = self.dense
        return linear_layer.attr_backward(rel_y, eps=eps)


class LRPBertModel(BackpropBertModel):
    """
    BertModel with LRP.
    """
    _bert_layer_types = {bert.BertEmbeddings: LRPBertEmbeddings,
                         bert.BertEncoder: LRPBertEncoder,
                         bert.BertPooler: LRPBertPooler}

    def attr_backward(self, rel_sequence: HiddenArray = None,
                      rel_pooled: np.ndarray = None, eps: float = 0.001) -> \
            Tuple[HiddenArray, HiddenArray, HiddenArray, HiddenArray]:
        """
        Propagates relevance from the model outputs back to the
        embeddings. At least one of rel_sequence and rel_pooled must be
        given.
        :param rel_sequence: The relevance of the sequence output.
            Defaults to zeros of the stored output shape
        :param rel_pooled: The relevance of the pooled output; it is
            sent through the pooler and added at position 0 of the
            sequence relevance
        :param eps: The LRP stabilizer, used for every step
        :return: The items returned by the embedding layer's backward
            pass, followed by the relevance of the embedding layer
            output
        """
        assert rel_sequence is not None or rel_pooled is not None

        if rel_sequence is None:
            rel_sequence = np.zeros(self._state["output_shape"])
        if rel_pooled is not None:
            rel_first = self.pooler.attr_backward(rel_pooled, eps=eps)
            # Add to a copy so the caller's array is left untouched.
            rel_sequence = np.array(rel_sequence)
            rel_sequence[:, 0] += rel_first

        rel_embeddings = self.encoder.attr_backward(rel_sequence, eps=eps)
        # eps is passed to the embeddings as well; this call used to
        # fall back to the default stabilizer.
        return self.embeddings.attr_backward(rel_embeddings, eps=eps) + \
               (rel_embeddings,)


BBFSC = BackpropBertForSequenceClassification


class LRPBertForSequenceClassification(LRPBertMixin, BBFSC):
    """
    A BERT model with a linear classifier on top, with LRP.
    """
    _bert_layer_types = {bert.BertModel: LRPBertModel}

    def attr_backward(self, rel_y: np.ndarray, eps: float = 0.001) -> \
            Tuple[HiddenArray, HiddenArray, HiddenArray]:
        """
        Sends the relevance through the classifier and then through
        the BERT model as the relevance of its pooled output.
        :param rel_y: The relevance of the classifier output
        :param eps: The LRP stabilizer
        :return: Whatever the BERT model's backward pass returns
        """
        pooled_relevance = self.classifier.attr_backward(rel_y, eps=eps)
        return self.bert.attr_backward(rel_pooled=pooled_relevance, eps=eps)

In [None]:
import numpy as np
import torch
from IPython.core.display import display, HTML
from transformers import BertTokenizer


In [None]:
pip install yattag

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting yattag
  Downloading yattag-1.14.0.tar.gz (26 kB)
Building wheels for collected packages: yattag
  Building wheel for yattag (setup.py) ... [?25l[?25hdone
  Created wheel for yattag: filename=yattag-1.14.0-py3-none-any.whl size=15659 sha256=e5d282b934cf6d7a0e9a95ac96b356b3a6c0e3d5d0398eee57c755c4cc31ec1a
  Stored in directory: /root/.cache/pip/wheels/4d/32/61/f205e276a280e24c3fca996bd956781b2a0fbad498161e53f4
Successfully built yattag
Installing collected packages: yattag
Successfully installed yattag-1.14.0


In [None]:
from typing import Callable, List, Tuple

import matplotlib.pyplot as plt
from yattag import Doc


def _rescale_score_by_abs(score: float, max_score: float,
                          min_score: float) -> float:
    """
    Normalizes an attribution score to the range [0., 1.], where a score
    score of 0. is mapped to 0.5.
    :param score: An attribution score
    :param max_score: The maximum possible attribution score
    :param min_score: The minimum possible attribution score
    :return: The normalized score
    """
    if -1e-5 < min_score and max_score < 1e-5:
        return .5
    elif max_score == min_score and min_score < 0:
        return 0.
    elif max_score == min_score and max_score > 0:
        return 1.

    top = max(abs(max_score), abs(min_score))
    return (score + top) / (2. * top)


def _get_rgb(c_tuple: Tuple[float]) -> str:
    """
    Converts a color from a tuple with values in [0., 1.] to RGB format.
    :param c_tuple: A color
    :return: The color, in RGB format
    """
    return "#%02x%02x%02x" % tuple(int(i * 255.) for i in c_tuple[:3])


def _span_word(tag: Callable, text: Callable, word: str, score: float,
               colormap: Callable):
    """
    Writes one word as a bold HTML <span>, followed by a space. The
    text color is the colormap value at the given score; the
    background is always white.
    :param tag: The tag() method from yattag
    :param text: The text() method from yattag
    :param word: A word
    :param score: The word's attribution score
    :param colormap: A matplotlib colormap
    :return: None
    """
    hex_color = _get_rgb(colormap(score))
    css = ("color:{};font-weight:bold;background-color: #ffffff;"
           "padding-top: 15px;padding-bottom: 15px;").format(hex_color)
    with tag("span", style=css):
        text(" " + word + " ")
    text(" ")


def html_heatmap(tokens: List[str], scores: List[float],
                 cmap_name: str = "coolwarm") -> str:
    """
    Constructs a word-level heatmap in HTML format.
    :param tokens: A sequence of tokens
    :param scores: The attribution score assigned to each token; must
        have the same length as tokens
    :param cmap_name: A matplotlib diverging colormap
    :return: The heatmap, as HTML code. Empty input gives an empty
        string
    """
    colormap = plt.get_cmap(cmap_name)

    assert len(tokens) == len(scores)
    # default=0. keeps max()/min() from raising ValueError when there
    # are no tokens at all.
    max_s = max(scores, default=0.)
    min_s = min(scores, default=0.)

    doc, tag, text = Doc().tagtext()

    for token, raw_score in zip(tokens, scores):
        score = _rescale_score_by_abs(raw_score, max_s, min_s)
        _span_word(tag, text, token, score, colormap)

    return doc.getvalue()


def latex_heatmap(tokens: List[str], scores: List[float],
                  cmap_name: str = "coolwarm") -> str:
    """
    Constructs a word-level heatmap in LaTeX format.
    :param tokens: A sequence of words. They are inserted verbatim, so
        LaTeX special characters must be escaped by the caller
    :param scores: The attribution score assigned to each token; must
        have the same length as tokens
    :param cmap_name: A matplotlib diverging colormap
    :return: The heatmap, as LaTeX code. Empty input gives an empty
        string
    """
    colormap = plt.get_cmap(cmap_name)

    assert len(tokens) == len(scores)
    # default=0. keeps max()/min() from raising ValueError when there
    # are no tokens at all.
    max_s = max(scores, default=0.)
    min_s = min(scores, default=0.)

    code_template = "\\textcolor[rgb]{{{},{},{}}}{{\\textbf{{{}}}}} "
    pieces = []
    for token, raw_score in zip(tokens, scores):
        score = _rescale_score_by_abs(raw_score, max_s, min_s)
        # The colormap returns RGBA; the alpha channel is not used.
        r, g, b, _ = colormap(score)
        pieces.append(code_template.format(r, g, b, token))

    return "".join(pieces)

In [None]:
import pandas as pd


In [None]:
# Read the dataset; the file is decoded with the "latin" codec and row 0
# supplies the column names.
df = pd.read_csv(
    "dataSpecial.csv",
    header=[0],
    encoding="latin",
)
# Bare last expression so the notebook renders the frame.
df

Unnamed: 0,Description,Semantic
0,Finnish Talentum reports its operating profit ...,positive
1,"Lifetree was founded in 2000 , and its revenue...",positive
2,Nokia also noted the average selling price of ...,positive
3,Calls to the switchboard and directory service...,negative
4,"Earnings per share EPS are seen at EUR 0.56 , ...",positive
5,The growth of net sales has continued favourab...,positive
6,The company slipped to an operating loss of EU...,negative
7,The company 's profit before taxes fell to EUR...,negative
8,Unit costs for flight operations fell by 6.4 p...,negative
9,"Tiimari , the Finnish retailer , reported to h...",positive


In [None]:
# Separate the sentences from their sentiment annotations.
data, labels = df["Description"], df["Semantic"]

In [None]:
# Show the sentence stored under index label 0.
print(data.loc[0])

Finnish Talentum reports its operating profit increased to EUR 20.5 mn in 2005 from EUR 9.3 mn in 2004 , and net sales totaled EUR 103.3 mn , up from EUR 96.4 mn .


In [None]:
# Rebuild the LRP BERT classifier from its saved config and weights.
print("Loading model...")
config_path = "bert-sst-config.pt"
state_dict_path = "bert-sst.pt"

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# NOTE: torch.load unpickles arbitrary Python objects; only load checkpoint
# files that come from a trusted source.
# map_location="cpu" lets the checkpoint load even when the device it was
# saved from is not available; nothing in this cell moves the model off the
# CPU, so the loaded weights end up in the same place either way.
model = LRPBertForSequenceClassification(
    torch.load(config_path, map_location="cpu"))
model.load_state_dict(torch.load(state_dict_path, map_location="cpu"))
model.eval()
print("Done.")

Loading model...
Done.


In [None]:
# Example sentence to classify and explain (kept verbatim, including the
# spaces around punctuation).
text = (
    "Finnish Talentum reports its operating profit increased to "
    "EUR 20.5 mn in 2005 from EUR 9.3 mn in 2004 , and net sales totaled "
    "EUR 103.3 mn , up from EUR 96.4 mn ."
)

In [None]:
# Ordinary forward pass of the LRP BERT model to inspect its logits.
model.eval()
inputs = tokenizer(text, return_tensors="pt")
# The logits are only printed, so there is no need to record an autograd
# graph for this pass.
with torch.no_grad():
    logits = model(**inputs).logits.squeeze()

# Names paired with the output units by position.
# NOTE(review): presumably the label order used when the model was trained
# -- confirm against the training label vocabulary.
classes = ["<unk>", "positive", "negative", "neutral"]
print("Logit Scores:")
for c, score in zip(classes, logits):
    print("{}: {}".format(c, score))

Logit Scores:
<unk>: -3.0469701290130615
positive: 2.0974607467651367
negative: -0.07706346362829208
neutral: 0.8291009664535522


In [None]:
# Call attr() on the model (presumably switching it to attribution mode)
# and run the forward pass whose output the LRP cells below start from.
model.attr()
inputs = tokenizer(text, return_tensors="pt")
output = model(**inputs)

print("Attr Forward Pass Output:", output, sep="\n")



Attr Forward Pass Output:
[[-3.0469708   2.0974607  -0.07706366  0.8291011 ]]


In [None]:
tokens = tokenizer.tokenize(text)

# Relevance to propagate backwards: zero everywhere except output unit 1,
# which is seeded with its own score from the attr forward pass.
rel_y = np.zeros(output.shape)
print(rel_y)
rel_y[:, 1] = output[:, 1]
print(rel_y[:, 1], output[:, 1], rel_y, sep="\n")

# attr_backward returns four relevance arrays (word, positional, type and
# combined embeddings). For each one, take batch entry 0, drop the first
# and last positions (presumably the special tokens -- tokens above has
# none) and sum over the last axis to get one score per token.
rel_word, rel_pos, rel_type, rel_embed = (
    np.sum(rel[0, 1:-1], -1)
    for rel in model.attr_backward(rel_y, eps=.1)
)

print("LRP Scores:")
for token, relevance in zip(tokens, rel_embed):
    print(token, relevance, sep=": ")

# One heatmap per embedding type, each preceded by its caption.
for caption, relevance_scores in (
        ("Relevance of word embeddings:", rel_word),
        ("Relevance of positional embeddings:", rel_pos),
        ("Relevance of type embeddings:", rel_type),
        ("Relevance of combined embeddings:", rel_embed)):
    print(caption)
    display(HTML(html_heatmap(tokens, list(relevance_scores))))

[[0. 0. 0. 0.]]
[2.09746075]
[2.0974607]
[[0.         2.09746075 0.         0.        ]]
LRP Scores:
finnish: -0.03818813363897409
talent: 0.037735748815923176
##um: -0.007421630797781231
reports: -0.006879173907813348
its: -0.017306320650560904
operating: 0.022943802915842713
profit: 0.01082073476918048
increased: -0.001143817423679491
to: 0.003650053711893085
eu: 0.011351084231665053
##r: 0.019752153379993465
20: -0.002586911065000123
.: 0.0035768710604912645
5: -0.0008076440723636111
mn: 0.004068728135263494
in: -0.0020539840626747313
2005: -0.011911999336193538
from: 0.007878153728688827
eu: 0.012860429624884756
##r: -0.002699808758882861
9: -0.00482746921741147
.: 0.0047750359956760385
3: -0.006139836594876511
mn: 0.00107325293943088
in: -0.005613749823510426
2004: 0.00044051247825661844
,: -0.003091936056496245
and: 0.0045946914110006135
net: 0.018538889583819204
sales: 0.024043416978601316
totaled: 0.006257266955498782
eu: 0.013419341292337357
##r: -0.004484486450248701
103: 0.0

Relevance of positional embeddings:


Relevance of type embeddings:


Relevance of combined embeddings:


In [None]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Tokenizer and classifier are both loaded from the same checkpoint name.
finbert_checkpoint = "ProsusAI/finbert"
tokenizer2 = AutoTokenizer.from_pretrained(finbert_checkpoint)
model2 = AutoModelForSequenceClassification.from_pretrained(finbert_checkpoint)

Downloading:   0%|          | 0.00/252 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/758 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/418M [00:00<?, ?B/s]

In [None]:
# Forward pass of FinBERT to inspect its logits.
model2.eval()
inputs = tokenizer2(text, return_tensors="pt")
# The logits are only printed, so skip recording an autograd graph.
with torch.no_grad():
    logits = model2(**inputs).logits.squeeze()

# Take the label names from the checkpoint's own config instead of a
# hand-written list: the previous list had four entries (one misspelled
# "unkown") although the model emits three logits.
classes = [model2.config.id2label[i] for i in range(model2.config.num_labels)]
print("Logit Scores:")
for c, score in zip(classes, logits):
    print("{}: {}".format(c, score))

Logit Scores:
positive: 2.0532076358795166
negative: -1.6006394624710083
neutral: -1.684561848640442


In [None]:
inputs = tokenizer2(text, return_tensors="pt")
# Inference only: no autograd graph is needed for reading out a score.
with torch.no_grad():
    output = model2(**inputs)

# NOTE(review): the message says "Attr", but this is an ordinary forward
# pass of model2; the text is left unchanged to keep the printed output.
print("Attr Forward Pass Output:")

# output[0] is the first element of the model output; convert it to nested
# Python lists and pick batch entry 0, output unit 0.
x = output[0]
a = x.tolist()
b = a[0][0]
print(b)


Attr Forward Pass Output:
2.0532076358795166


In [None]:
# Show again the FinBERT score picked out in the previous cell.
print(b)

2.0532076358795166


In [None]:
tokens = tokenizer.tokenize(text)

# Relevance to propagate through `model`: zero everywhere except unit 0,
# seeded with the FinBERT score `b` read out above. Using `b` instead of a
# number pasted from an earlier run keeps the seed in sync with the actual
# FinBERT output (the pasted literal here no longer matched it).
# NOTE(review): `model` is the LRP BERT, whose label list puts "<unk>" at
# index 0, whereas `b` is FinBERT's index-0 logit -- confirm that seeding
# unit 0 of `model` is what is intended.
rel_y = np.zeros((1, 4))
print(rel_y)
rel_y[0][0] = b
print(rel_y)

# NOTE(review): attr_backward presumably relies on state from the most
# recent attr-mode forward pass of `model` -- verify that pass used `text`.
rel_word, rel_pos, rel_type, rel_embed = model.attr_backward(rel_y, eps=.1)
# Batch entry 0, first and last positions dropped, summed over the last
# axis: one score per token.
rel_word = np.sum(rel_word[0, 1:-1], -1)
rel_pos = np.sum(rel_pos[0, 1:-1], -1)
rel_type = np.sum(rel_type[0, 1:-1], -1)
rel_embed = np.sum(rel_embed[0, 1:-1], -1)

print("Logit Scores:")
for c, score in zip(classes, logits):
    print("{}: {}".format(c, score))

print("LRP Scores:")
for t, s in zip(tokens, rel_embed):
    print(t, s, sep=": ")

print("Relevance of word embeddings:")
display(HTML(html_heatmap(tokens, list(rel_word))))

print("Relevance of positional embeddings:")
display(HTML(html_heatmap(tokens, list(rel_pos))))

print("Relevance of type embeddings:")
display(HTML(html_heatmap(tokens, list(rel_type))))

print("Relevance of combined embeddings:")
display(HTML(html_heatmap(tokens, list(rel_embed))))

[[0. 0. 0. 0.]]
[[2.05935478 0.         0.         0.        ]]
Logit Scores:
<unk>: -3.0469701290130615
positive: 2.0974607467651367
negative: -0.07706346362829208
neutral: 0.8291009664535522
LRP Scores:
finnish: -0.015083151810027392
talent: 0.03046552438389919
##um: -0.0029068948274464253
reports: 0.003365928862888454
its: -0.005139797694264903
operating: -0.005980635297764278
profit: -0.004485777702605634
increased: 0.018507118181388983
to: 0.005815585790140922
eu: -0.0010005367884556612
##r: 0.005103287285074233
20: -0.00046220792137128674
.: 0.003971829360427951
5: -0.0008252270056323931
mn: -0.003060380487835141
in: -0.0022266762251118392
2005: 0.0032016751513549643
from: 0.0009107376297346037
eu: 0.01153042527997719
##r: 0.0029490487526264604
9: -0.003730070297442852
.: 0.002197356588827759
3: -0.0036560728267331717
mn: 0.0017351964691177375
in: 0.008917203876010981
2004: -0.007638589634741528
,: -0.0020020730141312513
and: 0.009650894814915725
net: 0.0283564024130561
sales: 0.

Relevance of positional embeddings:


Relevance of type embeddings:


Relevance of combined embeddings:


In [None]:
# Rebuild the LRP BERT classifier from its saved config and weights.
print("Loading model...")
config_path = "bert-sst-config.pt"
state_dict_path = "bert-sst.pt"

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# NOTE: torch.load unpickles arbitrary Python objects; only load checkpoint
# files that come from a trusted source.
# map_location="cpu" lets the checkpoint load even when the device it was
# saved from is not available; nothing in this cell moves the model off the
# CPU, so the loaded weights end up in the same place either way.
model = LRPBertForSequenceClassification(
    torch.load(config_path, map_location="cpu"))
model.load_state_dict(torch.load(state_dict_path, map_location="cpu"))
model.eval()
print("Done.")

Loading model...
Done.


In [None]:
# Second example sentence (kept verbatim).
text = (
    "The levelled of net sales has continued favourably in the "
    "Middle East and Africaand in Asia Pacific ."
)

In [None]:
# Ordinary forward pass of the LRP BERT model to inspect its logits.
model.eval()
inputs = tokenizer(text, return_tensors="pt")
# The logits are only printed, so there is no need to record an autograd
# graph for this pass.
with torch.no_grad():
    logits = model(**inputs).logits.squeeze()

# Names paired with the output units by position.
# NOTE(review): presumably the label order used when the model was trained
# -- confirm against the training label vocabulary.
classes = ["<unk>", "positive", "negative", "neutral"]
print("Logit Scores:")
for c, score in zip(classes, logits):
    print("{}: {}".format(c, score))

Logit Scores:
<unk>: -3.553102731704712
positive: 1.905517578125
negative: -0.4710124433040619
neutral: 1.100406527519226


In [None]:
# Call attr() on the model (presumably switching it to attribution mode)
# and run the forward pass whose output the LRP cells below start from.
model.attr()
inputs = tokenizer(text, return_tensors="pt")
output = model(**inputs)

print("Attr Forward Pass Output:", output, sep="\n")

Attr Forward Pass Output:
[[-3.5531042   1.905518   -0.47101274  1.1004071 ]]




In [None]:
tokens = tokenizer.tokenize(text)

# Relevance to propagate backwards: zero everywhere except output unit 1,
# which is seeded with its own score from the attr forward pass.
rel_y = np.zeros(output.shape)
print(rel_y)
rel_y[:, 1] = output[:, 1]
print(rel_y[:, 1], output[:, 1], rel_y, sep="\n")

# attr_backward returns four relevance arrays (word, positional, type and
# combined embeddings). For each one, take batch entry 0, drop the first
# and last positions (presumably the special tokens -- tokens above has
# none) and sum over the last axis to get one score per token.
rel_word, rel_pos, rel_type, rel_embed = (
    np.sum(rel[0, 1:-1], -1)
    for rel in model.attr_backward(rel_y, eps=.1)
)

print("Logit Scores:")
for label, logit in zip(classes, logits):
    print("{}: {}".format(label, logit))

print("LRP Scores:")
for token, relevance in zip(tokens, rel_embed):
    print(token, relevance, sep=": ")

# One heatmap per embedding type, each preceded by its caption.
for caption, relevance_scores in (
        ("Relevance of word embeddings:", rel_word),
        ("Relevance of positional embeddings:", rel_pos),
        ("Relevance of type embeddings:", rel_type),
        ("Relevance of combined embeddings:", rel_embed)):
    print(caption)
    display(HTML(html_heatmap(tokens, list(relevance_scores))))

[[0. 0. 0. 0.]]
[1.90551805]
[1.905518]
[[0.         1.90551805 0.         0.        ]]
Logit Scores:
<unk>: -3.553102731704712
positive: 1.905517578125
negative: -0.4710124433040619
neutral: 1.100406527519226
LRP Scores:
the: 0.006262285116170857
level: -0.03606136658233169
##led: -0.10528512439599894
of: 0.009089939681924095
net: 0.007485208198103687
sales: 0.07631732766832283
has: 0.04385774535661372
continued: 0.0022478593556794833
favour: 0.12837076449977028
##ably: 0.005304890502708113
in: 0.0051698459659604065
the: 0.0026148309302867828
middle: 0.00830559617858398
east: 0.0016733791304399692
and: 0.0055035151522878824
africa: -0.0034349870663595333
##and: -0.01943764243553721
in: -0.00024433210582298307
asia: -0.020680236534181082
pacific: 0.0031371719980436835
.: 0.006264863934069245
Relevance of word embeddings:


Relevance of positional embeddings:


Relevance of type embeddings:


Relevance of combined embeddings:


In [None]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Tokenizer and classifier are both loaded from the same checkpoint name.
finbert_checkpoint = "ProsusAI/finbert"
tokenizer2 = AutoTokenizer.from_pretrained(finbert_checkpoint)
model2 = AutoModelForSequenceClassification.from_pretrained(finbert_checkpoint)

In [None]:
# Forward pass of FinBERT to inspect its logits.
model2.eval()
inputs = tokenizer2(text, return_tensors="pt")
# The logits are only printed, so skip recording an autograd graph.
with torch.no_grad():
    logits = model2(**inputs).logits.squeeze()

# Take the label names from the checkpoint's own config instead of a
# hand-written list: the previous list had four entries (one misspelled
# "unkown") although the model emits three logits.
classes = [model2.config.id2label[i] for i in range(model2.config.num_labels)]
print("Logit Scores:")
for c, score in zip(classes, logits):
    print("{}: {}".format(c, score))

Logit Scores:
positive: 1.9893385171890259
negative: -2.4561164379119873
neutral: -0.659746527671814


In [None]:
inputs = tokenizer2(text, return_tensors="pt")
# Inference only: no autograd graph is needed for reading out a score.
with torch.no_grad():
    output = model2(**inputs)

# NOTE(review): the message says "Attr", but this is an ordinary forward
# pass of model2; the text is left unchanged to keep the printed output.
print("Attr Forward Pass Output:")

# output[0] is the first element of the model output; convert it to nested
# Python lists and pick batch entry 0, output unit 0.
x = output[0]
a = x.tolist()
b = a[0][0]
print(b)

Attr Forward Pass Output:
1.9893385171890259


In [None]:
tokens = tokenizer.tokenize(text)

# Relevance to propagate through `model`: zero everywhere except unit 0,
# seeded with the FinBERT score `b` read out above. Using `b` instead of a
# number pasted from the printed output keeps the seed correct whenever
# the text or the model changes.
# NOTE(review): `model` is the LRP BERT, whose label list puts "<unk>" at
# index 0, whereas `b` is FinBERT's index-0 logit -- confirm that seeding
# unit 0 of `model` is what is intended.
rel_y = np.zeros((1, 4))
print(rel_y)
rel_y[0][0] = b
print(rel_y)

# NOTE(review): attr_backward presumably relies on state from the most
# recent attr-mode forward pass of `model` -- verify that pass used `text`.
rel_word, rel_pos, rel_type, rel_embed = model.attr_backward(rel_y, eps=.1)
# Batch entry 0, first and last positions dropped, summed over the last
# axis: one score per token.
rel_word = np.sum(rel_word[0, 1:-1], -1)
rel_pos = np.sum(rel_pos[0, 1:-1], -1)
rel_type = np.sum(rel_type[0, 1:-1], -1)
rel_embed = np.sum(rel_embed[0, 1:-1], -1)

print("LRP Scores:")
for t, s in zip(tokens, rel_embed):
    print(t, s, sep=": ")

print("Relevance of word embeddings:")
display(HTML(html_heatmap(tokens, list(rel_word))))

print("Relevance of positional embeddings:")
display(HTML(html_heatmap(tokens, list(rel_pos))))

print("Relevance of type embeddings:")
display(HTML(html_heatmap(tokens, list(rel_type))))

print("Relevance of combined embeddings:")
display(HTML(html_heatmap(tokens, list(rel_embed))))

[[0. 0. 0. 0.]]
[[1.98933852 0.         0.         0.        ]]
LRP Scores:
with: 0.0018042054998812615
the: 0.000827071529022734
new: -0.008090928654314055
production: 0.019673007116687943
plant: 0.012546786687276917
the: 0.0019214967842660826
company: 0.013568262373840068
would: -0.04177473873757129
adjust: 0.016617726368736093
its: 0.01068764379592845
capacity: 0.021491058411293344
to: 0.002796822239700575
meet: 0.017430512721591747
the: 0.012407440759703743
expected: -0.025160071420322505
adjust: 0.03132264910102439
in: 0.030133524799463993
demand: -0.00783158479535886
and: 0.06091531084671442
would: -0.1261613901507531
improve: -0.009725580489688322
the: 0.00594131137782335
use: 0.018605695776914098
of: -0.004274377232101477
raw: 0.003765317150039936
materials: -0.0011455171827551265
and: 0.028426105374669698
therefore: 0.036321726340894374
adjust: 0.020225269642584098
the: 0.0028652346196266533
production: 0.013833323666016032
profit: 0.022166813168972645
##ability: -0.0218434999

Relevance of positional embeddings:


Relevance of type embeddings:


Relevance of combined embeddings:


In [None]:
# Rebuild the LRP BERT classifier from its saved config and weights.
print("Loading model...")
config_path = "bert-sst-config.pt"
state_dict_path = "bert-sst.pt"

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# NOTE: torch.load unpickles arbitrary Python objects; only load checkpoint
# files that come from a trusted source.
# map_location="cpu" lets the checkpoint load even when the device it was
# saved from is not available; nothing in this cell moves the model off the
# CPU, so the loaded weights end up in the same place either way.
model = LRPBertForSequenceClassification(
    torch.load(config_path, map_location="cpu"))
model.load_state_dict(torch.load(state_dict_path, map_location="cpu"))
model.eval()
print("Done.")

Loading model...
Done.


In [None]:
# Third example sentence (kept verbatim).
text = (
    "With the new production plant the company would decrease its "
    "capacity to meet the expected decrease in demand and would improve "
    "the use of raw materials and therefore decrease the production "
    "profitability."
)

In [None]:
# Ordinary forward pass of the LRP BERT model to inspect its logits.
model.eval()
inputs = tokenizer(text, return_tensors="pt")
# The logits are only printed, so there is no need to record an autograd
# graph for this pass.
with torch.no_grad():
    logits = model(**inputs).logits.squeeze()

# Names paired with the output units by position.
# NOTE(review): presumably the label order used when the model was trained
# -- confirm against the training label vocabulary.
classes = ["<unk>", "positive", "negative", "neutral"]
print("Logit Scores:")
for c, score in zip(classes, logits):
    print("{}: {}".format(c, score))

Logit Scores:
<unk>: -2.6713459491729736
positive: 0.9979174733161926
negative: 1.7519034147262573
neutral: 1.0688958168029785


In [None]:
# Call attr() on the model (presumably switching it to attribution mode)
# and run the forward pass whose output the LRP cells below start from.
model.attr()
inputs = tokenizer(text, return_tensors="pt")
output = model(**inputs)

print("Attr Forward Pass Output:", output, sep="\n")



Attr Forward Pass Output:
[[-2.6713464  0.9979184  1.7519033  1.0688958]]


In [None]:
tokens = tokenizer.tokenize(text)

# Relevance to propagate backwards: zero everywhere except output unit 2,
# which is seeded with its own score from the attr forward pass.
rel_y = np.zeros(output.shape)
print(rel_y)
rel_y[:, 2] = output[:, 2]
print(rel_y[:, 2], output[:, 2], rel_y, sep="\n")

# attr_backward returns four relevance arrays (word, positional, type and
# combined embeddings). For each one, take batch entry 0, drop the first
# and last positions (presumably the special tokens -- tokens above has
# none) and sum over the last axis to get one score per token.
rel_word, rel_pos, rel_type, rel_embed = (
    np.sum(rel[0, 1:-1], -1)
    for rel in model.attr_backward(rel_y, eps=.1)
)

print("LRP Scores:")
for token, relevance in zip(tokens, rel_embed):
    print(token, relevance, sep=": ")

# One heatmap per embedding type, each preceded by its caption.
for caption, relevance_scores in (
        ("Relevance of word embeddings:", rel_word),
        ("Relevance of positional embeddings:", rel_pos),
        ("Relevance of type embeddings:", rel_type),
        ("Relevance of combined embeddings:", rel_embed)):
    print(caption)
    display(HTML(html_heatmap(tokens, list(relevance_scores))))

In [None]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Tokenizer and classifier are both loaded from the same checkpoint name.
finbert_checkpoint = "ProsusAI/finbert"
tokenizer2 = AutoTokenizer.from_pretrained(finbert_checkpoint)
model2 = AutoModelForSequenceClassification.from_pretrained(finbert_checkpoint)

In [None]:
# Forward pass of FinBERT to inspect its logits.
model2.eval()
inputs = tokenizer2(text, return_tensors="pt")
# The logits are only printed, so skip recording an autograd graph.
with torch.no_grad():
    logits = model2(**inputs).logits.squeeze()

# Take the label names from the checkpoint's own config instead of a
# hand-written list: the previous list had four entries (one misspelled
# "unkown") although the model emits three logits.
classes = [model2.config.id2label[i] for i in range(model2.config.num_labels)]
print("Logit Scores:")
for c, score in zip(classes, logits):
    print("{}: {}".format(c, score))

Logit Scores:
positive: -1.600598931312561
negative: 2.890672445297241
neutral: -0.977862536907196


In [None]:
inputs = tokenizer2(text, return_tensors="pt")
# Inference only: no autograd graph is needed for reading out a score.
with torch.no_grad():
    output = model2(**inputs)

# NOTE(review): the message says "Attr", but this is an ordinary forward
# pass of model2; the text is left unchanged to keep the printed output.
print("Attr Forward Pass Output:")

# output[0] is the first element of the model output; convert it to nested
# Python lists and pick batch entry 0, output unit 1.
x = output[0]
a = x.tolist()
b = a[0][1]
print(b)

Attr Forward Pass Output:
2.890672445297241


In [None]:
tokens = tokenizer.tokenize(text)

# Relevance to propagate through `model`: zero everywhere except unit 1,
# seeded with the FinBERT score `b` read out above. Using `b` instead of a
# number pasted from the printed output keeps the seed correct whenever
# the text or the model changes.
# NOTE(review): `model` is the LRP BERT, whose label list puts "positive"
# at index 1, whereas `b` is FinBERT's index-1 logit -- confirm that
# seeding unit 1 of `model` is what is intended.
rel_y = np.zeros((1, 4))
print(rel_y)
rel_y[0][1] = b
print(rel_y)

# NOTE(review): attr_backward presumably relies on state from the most
# recent attr-mode forward pass of `model` -- verify that pass used `text`.
rel_word, rel_pos, rel_type, rel_embed = model.attr_backward(rel_y, eps=.1)
# Batch entry 0, first and last positions dropped, summed over the last
# axis: one score per token.
rel_word = np.sum(rel_word[0, 1:-1], -1)
rel_pos = np.sum(rel_pos[0, 1:-1], -1)
rel_type = np.sum(rel_type[0, 1:-1], -1)
rel_embed = np.sum(rel_embed[0, 1:-1], -1)

print("LRP Scores:")
for t, s in zip(tokens, rel_embed):
    print(t, s, sep=": ")

print("Relevance of word embeddings:")
display(HTML(html_heatmap(tokens, list(rel_word))))

print("Relevance of positional embeddings:")
display(HTML(html_heatmap(tokens, list(rel_pos))))

print("Relevance of type embeddings:")
display(HTML(html_heatmap(tokens, list(rel_type))))

print("Relevance of combined embeddings:")
display(HTML(html_heatmap(tokens, list(rel_embed))))

[[0. 0. 0. 0.]]
[[0.         2.89067245 0.         0.        ]]
LRP Scores:
with: -0.010166933252796162
the: 0.5908704837557412
new: -0.28629376649299604
production: -0.0021112310116838985
plant: 0.2513508102559735
the: -0.12878531946571375
company: -0.2982337158513682
would: -0.8719197382313835
decrease: 0.2885911990020068
its: -0.11226109990247209
capacity: 0.40760836365467695
to: 0.6697607072633525
meet: -0.3520884529026946
the: -0.4104387791682438
expected: -0.41802376349583026
decrease: -0.10974260673639072
in: 0.14338384453397596
demand: -0.44970133760914366
and: -0.5429633429463552
would: -0.4372991912064631
improve: -0.3992668403824726
the: -0.17786212561857578
use: -0.01763908232789596
of: -0.031707650648228614
raw: -0.22408171865221838
materials: 0.01578401648435413
and: 0.06669214092039565
therefore: -0.445459606114455
decrease: -0.18787919589229052
the: -0.3639356058013059
production: -0.2629642260334204
profit: -0.11900865849706907
##ability: -0.3207311325510718
.: -0.2009

Relevance of positional embeddings:


Relevance of type embeddings:


Relevance of combined embeddings:
