In [1]:
from timm.models.vision_transformer import VisionTransformer
from torchvision.io import read_image
from torchvision import transforms
import torch.nn.functional as F
from torch.fx import symbolic_trace
from torch.autograd import Function
import requests
from PIL import Image
from io import BytesIO
import torch
import timm
import json
import urllib.request
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Build the 'vit_base_patch16_224' model through timm with pretrained weights.
model = timm.create_model('vit_base_patch16_224', pretrained=True)
# eval() puts the module in inference mode and returns the module itself, so
# `m` is a second name for `model`.
m = model.eval()

In [3]:
# Download a sample image.
url = "https://github.com/pytorch/hub/raw/master/images/dog.jpg"
# Bound the wait and surface HTTP errors here, so that a failed download is
# reported as such instead of as a PIL decoding error on an error page.
response = requests.get(url, timeout=30)
response.raise_for_status()
image = Image.open(BytesIO(response.content)).convert("RGB")

# Preprocessing: resize, take the central 224x224 crop, convert to a tensor and
# normalise each channel with the given mean / std.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225)
    ),
])
# unsqueeze(0) adds a leading dimension of size 1 in front of the image tensor.
data = transform(image).unsqueeze(0)

In [5]:
# Forward pass through the unmodified model with gradient tracking disabled.
# NOTE(review): `y` is reassigned by the manual layer-by-layer pass in a later
# cell, so this output is not kept around for comparison.
with torch.no_grad():
    y = model(data)

In [6]:
def conservation_check(func):
    """Decorate a relevance ``backward`` so its relevance sums can be logged.

    The wrapped callable is invoked as ``func(ctx, *out_relevance)`` and its
    result is always returned unchanged.  Logging is controlled by an optional
    module-level ``CONSERVATION_CHECK_FLAG`` sequence:

    * ``CONSERVATION_CHECK_FLAG[0]`` -- compute the incoming and outgoing sums;
    * ``CONSERVATION_CHECK_FLAG[1]`` -- additionally print them.

    When the flag is not defined the check is treated as disabled instead of
    raising ``NameError`` on every decorated call.
    """
    # TODO (kept from the original): bug in add2_fn
    # Local import so this cell does not rely on an import made elsewhere.
    import functools

    def _total(relevance):
        # A rule may hand back a bare tensor, a tuple with ``None`` entries,
        # or nothing at all; normalise before summing.
        if relevance is None:
            return 0
        if not isinstance(relevance, (tuple, list)):
            relevance = (relevance,)
        return sum(r.float().sum() for r in relevance if r is not None)

    @functools.wraps(func)
    def wrapped(ctx, *out_relevance):
        inp_relevance = func(ctx, *out_relevance)

        # Looked up at call time so the flag can be toggled between cells.
        flag = globals().get("CONSERVATION_CHECK_FLAG", (False, False))
        if flag[0]:
            out_rel_sum = _total(out_relevance)
            inp_elements = _total(inp_relevance)
            if flag[1]:
                print(func.__name__, out_rel_sum, inp_elements)

        return inp_relevance

    return wrapped

In [7]:
import torch
import torch.nn.functional as F


class epsilon_lrp_fn:
    """Gradient-times-input relevance rule for an arbitrary callable ``fn``.

    ``forward`` evaluates ``fn`` on detached copies of the inputs that require
    grad, so that ``backward`` can differentiate the stored output with
    ``torch.autograd.grad`` and multiply each gradient by its input.
    """

    def __init__(self, fn, epsilon):
        self.fn = fn
        self.epsilon = epsilon
        self.requires_grads = None
        self.inputs = None
        self.outputs = None

    def forward(self, *inputs):
        """Run ``fn(*inputs)`` and return its (detached) result."""
        # Non-tensor arguments are treated as not requiring grad, matching the
        # isinstance check used by add_2_tensors_fn and mul2_fn.
        self.requires_grads = [
            isinstance(inp, torch.Tensor) and inp.requires_grad for inp in inputs
        ]
        if not any(self.requires_grads):
            return self.fn(*inputs)

        # Fresh leaves: gradients are taken with respect to these copies only.
        self.inputs = tuple(
            inp.detach().requires_grad_() if req else inp
            for inp, req in zip(inputs, self.requires_grads)
        )
        with torch.enable_grad():
            self.outputs = self.fn(*self.inputs)
        return self.outputs.detach()

    @conservation_check
    def backward(self, *out_relevance):
        """Return one relevance per forward input (``None`` where untracked)."""
        flags = self.requires_grads
        # forward() took the early-return path: nothing was recorded, so there
        # is no relevance to hand back to any input.
        if flags is not None and not any(flags):
            return tuple(None for _ in flags)

        tracked = [inp for inp, req in zip(self.inputs, flags) if req]
        outputs = self.outputs
        # NOTE(review): epsilon is added without regard to the sign of the
        # output, so outputs close to -epsilon give a near-zero denominator.
        relevance_norm = out_relevance[0] / (outputs + self.epsilon)
        grads = torch.autograd.grad(outputs, tracked, relevance_norm)
        relevance = iter(g.mul_(x) for g, x in zip(grads, tracked))
        return tuple(next(relevance) if req else None for req in flags)


class identity_fn:
    """Pass-through rule: applies ``fn`` forward, returns relevance untouched."""

    def __init__(self, fn):
        self.fn = fn

    def forward(self, input):
        """Apply the wrapped callable, remember the result and return it."""
        result = self.fn(input)
        self.output = result
        return result

    @conservation_check
    def backward(self, *out_relevance):
        """Hand the incoming relevance tuple back exactly as received."""
        return out_relevance


class softmax_fn:
    """Softmax with a relevance rule of the form ``x * (R - softmax(x) * sum(R))``.

    ``temprature`` (spelling kept for interface compatibility) divides the
    inputs before the softmax is taken.
    """

    def __init__(self, dim, temprature=1.0, dtype=None, inplace=False):
        self.dim = dim
        self.temprature = temprature
        self.dtype = dtype
        self.inplace = inplace

    def forward(self, inputs):
        """Return ``softmax(inputs / temprature)`` along ``self.dim``."""
        if self.dtype is not None:
            inputs = inputs.to(self.dtype)
        inputs = inputs / self.temprature
        outputs = F.softmax(inputs, dim=self.dim, dtype=self.dtype)
        # The scaled inputs (not the raw ones) are what backward multiplies by.
        self.inputs, self.outputs = inputs, outputs
        return outputs

    @conservation_check
    def backward(self, *R_out):
        """Map output relevance ``R_out[0]`` to input relevance."""
        R_out = R_out[0]
        inputs, outputs = self.inputs, self.outputs
        # -inf entries (e.g. masked positions) would turn the product into NaN.
        inputs = torch.where(torch.isneginf(inputs), torch.tensor(0).to(inputs), inputs)
        # Sum over the axis the softmax was taken over; this used to be a
        # hard-coded -1, which is only right when dim is the last axis.
        total = R_out.sum(self.dim, keepdim=True)
        if self.inplace:
            # Overwrites both R_out and the stored outputs.
            R_in = R_out.sub_(outputs.mul_(total)).mul_(inputs)
        else:
            R_in = inputs * (R_out - outputs * total)
        return R_in


class linear_fn:
    """``F.linear`` with an epsilon-stabilised relevance rule."""

    def __init__(self, epsilon=1e-8):
        self.epsilon = epsilon

    def forward(self, inputs, weight, bias):
        """Compute ``F.linear(inputs, weight, bias)`` and cache what backward needs."""
        self.inputs = inputs
        self.weight = weight
        result = F.linear(inputs, weight, bias)
        self.outputs = result
        return result

    @conservation_check
    def backward(self, *R_out):
        """Redistribute ``R_out[0]`` onto the inputs of the linear map."""
        upstream = R_out[0]
        # Relevance per unit of (stabilised) output ...
        per_output = upstream / (self.outputs + self.epsilon)
        # ... sent back through the weights and scaled by the input.
        return torch.matmul(per_output, self.weight) * self.inputs


class matmul_fn:
    """``torch.matmul`` whose relevance is split between both operands.

    The denominator uses twice the output, so each operand receives a share of
    the incoming relevance.
    """

    def __init__(self, epsilon=1e-12, inplace=False):
        self.epsilon = epsilon
        self.inplace = inplace

    def forward(self, input_a, input_b):
        """Return ``input_a @ input_b`` and cache operands and result."""
        self.input_a, self.input_b = input_a, input_b
        self.output = torch.matmul(input_a, input_b)
        return self.output

    @staticmethod
    def _swap_last_two(tensor):
        """Transpose the two trailing (matrix) dimensions of ``tensor``.

        ``Tensor.T`` reverses *all* dimensions, which is wrong for batched
        operands and deprecated for tensors that are not 2-D.  Tensors with
        fewer than two dimensions are returned unchanged, as ``.T`` did.
        """
        return tensor.transpose(-2, -1) if tensor.dim() >= 2 else tensor

    @conservation_check
    def backward(self, *R_out):
        """Return ``(relevance_a, relevance_b)`` for the two operands."""
        R_out = R_out[0]
        if self.inplace:
            # Overwrites both R_out and the cached output.
            S = R_out.div_(self.output.mul_(2).add_(self.epsilon))
        else:
            S = R_out / ((2 * self.output) + self.epsilon)
        R_ina = torch.matmul(S, self._swap_last_two(self.input_b)).mul_(self.input_a)
        R_inb = torch.matmul(self._swap_last_two(self.input_a), S).mul_(self.input_b)
        return R_ina, R_inb


class add_2_tensors_fn:
    """Addition whose relevance is split in proportion to each summand.

    ``inplace`` is stored but not consulted by ``forward`` or ``backward``.
    """

    def __init__(self, inplace=False, epsilon=1e-8):
        self.inplace = inplace
        self.epsilon = epsilon

    def forward(self, input_a, input_b):
        """Return ``input_a + input_b`` and note which operands require grad."""
        self.input_a = input_a
        self.input_b = input_b
        tracked = []
        for position, operand in enumerate((input_a, input_b)):
            if isinstance(operand, torch.Tensor) and operand.requires_grad:
                tracked.append(position)
        # Positions (0 and/or 1) of the operands that are tensors requiring grad.
        self.requires_grads = tracked
        self.outputs = input_a + input_b
        return self.outputs

    @conservation_check
    def backward(self, *R_out):
        """Return ``(relevance_a, relevance_b)``; ``None`` for untracked operands."""
        tracked = self.requires_grads
        if len(tracked) == 0:
            return None, None
        upstream = R_out[0]
        # The sum is recomputed from the cached operands, plus the stabiliser.
        total = (self.input_a + self.input_b + self.epsilon)
        share_a = (upstream * self.input_a) / total
        share_b = (upstream * self.input_b) / total
        relevance_a = share_a if 0 in tracked else None
        relevance_b = share_b if 1 in tracked else None
        return (relevance_a, relevance_b)


class mul2_fn:
    """Element-wise product whose relevance is shared equally between operands."""

    def __init__(self, inplace=False):
        self.inplace = inplace

    def forward(self, input_a, input_b):
        """Return ``input_a * input_b`` and note which operands require grad."""
        self.input_a = input_a
        self.input_b = input_b
        tracked = []
        for position, operand in enumerate((input_a, input_b)):
            if isinstance(operand, torch.Tensor) and operand.requires_grad:
                tracked.append(position)
        self.requires_grads = tracked
        return input_a * input_b

    @conservation_check
    def backward(self, *R_out):
        """Divide ``R_out[0]`` by the number of tracked operands.

        Each tracked operand receives that same tensor; untracked ones get ``None``.
        """
        tracked = self.requires_grads
        count = len(tracked)
        upstream = R_out[0]
        if self.inplace:
            shared = upstream.div_(count)  # overwrites the incoming relevance
        else:
            shared = upstream / count
        return tuple(shared if position in tracked else None for position in (0, 1))


class layernorm_fn:
    """Layer normalisation over the last axis, with the std treated as constant.

    Because the standard deviation is detached, gradients taken in ``backward``
    do not flow through it.
    """

    def __init__(self, epsilon=1e-6, var_epsilon=1e-6):
        self.epsilon = epsilon
        self.var_epsilon = var_epsilon

    def forward(self, input, weight, bias):
        """Normalise ``input`` over its last axis, then apply ``weight`` and ``bias``."""
        # Grad mode is forced on so backward can differentiate y w.r.t. input.
        with torch.enable_grad():
            mu = input.mean(-1, keepdim=True)
            centered = input - mu
            variance = (centered ** 2).mean(-1, keepdim=True)
            sigma = (variance + self.var_epsilon).sqrt()
            normalized = centered / sigma.detach()
            y = normalized * weight + bias
            self.input = input
            self.y = y
        return y.detach()

    @conservation_check
    def backward(self, *R_out):
        """Gradient-times-input relevance for the cached forward pass."""
        upstream = R_out[0]
        scaled = upstream / (self.y + self.epsilon)
        (grad_input,) = torch.autograd.grad(self.y, self.input, scaled)
        return grad_input * self.input


class conv_fn:
    """Relevance rule for a 2-D convolution with box-bounded inputs.

    ``lowest`` / ``highest`` are the values used to fill the lower- and
    upper-bound tensors; the positive and negative parts of the weight are
    paired with them respectively.
    NOTE(review): this looks like the z-box input-layer rule, in which case the
    bounds should match the real range of the (normalised) inputs -- confirm.
    """

    def __init__(self, lowest=0., highest=1.):
        self.lowest = lowest
        self.highest = highest

    def forward(self, inputs, module):
        """Run ``module(inputs)`` and cache the geometry needed by ``backward``."""
        self.module = module
        self.inputs = inputs
        self.output = module(inputs)
        self.stride, self.padding, self.kernel = module.stride, module.padding, module.kernel_size
        self.weight = module.weight
        return self.output

    @conservation_check
    def backward(self, *R_out):
        """Map output relevance ``R_out[0]`` back onto the convolution input."""
        R = R_out[0]
        stride, padding, kernel = self.stride, self.padding, self.kernel
        activation, Z_O, weight = self.inputs, self.output, self.weight
        # Extra rows/columns the transposed convolution needs to reproduce the
        # input size.  Computed per spatial axis: the height-only value used
        # before is wrong as soon as height and width geometry differ.
        output_padding = tuple(
            activation.size(axis + 2)
            - ((R.size(axis + 2) - 1) * stride[axis] - 2 * padding[axis] + kernel[axis])
            for axis in range(2)
        )
        W_L = torch.clamp(weight, min=0)  # positive part, paired with `lowest`
        W_H = torch.clamp(weight, max=0)  # negative part, paired with `highest`

        L = torch.ones_like(activation) * self.lowest
        H = torch.ones_like(activation) * self.highest
        Z_L = F.conv2d(L, W_L, stride=stride, padding=padding)
        Z_H = F.conv2d(H, W_H, stride=stride, padding=padding)
        Z = Z_O - Z_L - Z_H + 1e-9  # small constant guards against division by zero
        S = R / Z

        C_O = F.conv_transpose2d(S, weight, stride=stride, padding=padding, output_padding=output_padding)
        C_L = F.conv_transpose2d(S, W_L, stride=stride, padding=padding, output_padding=output_padding)
        C_H = F.conv_transpose2d(S, W_H, stride=stride, padding=padding, output_padding=output_padding)
        R_in = activation * C_O - L * C_L - H * C_H
        return R_in


In [8]:
def softmax(inputs, dim, temprature=1.0, dtype=None, inplace=False, **kwargs):
    """Forward-only entry point for :class:`softmax_fn`.

    Extra keyword arguments are accepted and ignored.  The rule object is not
    returned, so only the forward result is available to the caller.
    """
    rule = softmax_fn(dim=dim, temprature=temprature, dtype=dtype, inplace=inplace)
    return rule.forward(inputs)


def linear(fn, inputs, epsilon=1e-6, **kwargs):
    """Apply ``fn``'s ``weight`` and ``bias`` to ``inputs`` through :class:`linear_fn`.

    Extra keyword arguments are accepted and ignored.
    """
    rule = linear_fn(epsilon=epsilon)
    return rule.forward(inputs, fn.weight, fn.bias)


def matmul(input_a, input_b, epsilon=1e-12, inplace=False, **kwargs):
    """Multiply two tensors through :class:`matmul_fn`; extra kwargs are ignored."""
    rule = matmul_fn(epsilon=epsilon, inplace=inplace)
    return rule.forward(input_a, input_b)


def add(input_a, input_b, inplace=False, epsilon=1e-8, **kwargs):
    """Add two operands through :class:`add_2_tensors_fn`; extra kwargs are ignored."""
    rule = add_2_tensors_fn(inplace=inplace, epsilon=epsilon)
    return rule.forward(input_a, input_b)


def mul(input_a, input_b, inplace=False, **kwargs):
    """Multiply two operands through :class:`mul2_fn`; extra kwargs are ignored."""
    rule = mul2_fn(inplace=inplace)
    return rule.forward(input_a, input_b)


def layernorm(fn, inputs, epsilon=1e-6, var_epsilon=1e-6, **kwargs):
    """Normalise ``inputs`` with ``fn``'s ``weight`` and ``bias`` via :class:`layernorm_fn`.

    Extra keyword arguments are accepted and ignored.
    """
    rule = layernorm_fn(epsilon=epsilon, var_epsilon=var_epsilon)
    return rule.forward(inputs, fn.weight, fn.bias)


def identity(fn, inputs, **kwargs):
    """Run ``fn(inputs)`` through :class:`identity_fn`; extra kwargs are ignored."""
    rule = identity_fn(fn)
    return rule.forward(inputs)


def epsilon_lrp(fn, epsilon, *inputs, **kwargs):
    """Evaluate ``fn(*inputs)`` through :class:`epsilon_lrp_fn`.

    Keyword arguments are accepted and ignored; they are not forwarded to ``fn``.
    """
    rule = epsilon_lrp_fn(fn, epsilon)
    return rule.forward(*inputs)


def conv_2d(fn, inputs, lowest=0., highest=1., **kwargs):
    """Run the convolution module ``fn`` on ``inputs`` through :class:`conv_fn`.

    Args:
        fn: convolution module; ``conv_fn.forward`` calls it and reads its
            ``stride``, ``padding``, ``kernel_size`` and ``weight``.
        inputs: tensor passed to ``fn``.
        lowest, highest: bounds forwarded to :class:`conv_fn`.  The defaults
            reproduce the previous hard-wired behaviour; pass the real range
            of the inputs when they are not confined to ``[0, 1]``.
        **kwargs: accepted and ignored.
    """
    return conv_fn(lowest=lowest, highest=highest).forward(inputs, fn)

def multihead_attn_fn_cp(fn,
        x,
        num_heads: int = 12,
        attn_mask=None,
        qk_norm=False,
        scale_norm=False, **kwargs):
    """Multi-head self-attention forward pass built from ``fn``'s sub-modules.

    The attention weights are detached before being multiplied with the
    values, so only the value path is tracked by ``epsilon_lrp``.

    Args:
        fn: attention module providing ``qkv``, ``q_norm``, ``k_norm``,
            ``attn_drop``, ``norm``, ``proj`` and ``proj_drop``.
        x: input tensor; it is unpacked as ``B, N, C = x.shape``.
        num_heads: number of heads; must divide the size of the last axis.
        attn_mask: optional tensor added to the attention logits.
        qk_norm, scale_norm: when either is set, the norm sub-modules of
            ``fn`` are checked for presence.
        **kwargs: accepted and ignored.
    """
    qkv, q_norm, k_norm, attn_drop, norm, proj, proj_drop, dim = fn.qkv, fn.q_norm, fn.k_norm, fn.attn_drop, fn.norm, fn.proj, fn.proj_drop, int(x.shape[-1])
    assert dim%num_heads==0, 'dim should be divisible by num_heads'
    if qk_norm or scale_norm:
        # This used to test an undefined name `norm_layer`, which raised
        # NameError as soon as either flag was set.  Check the modules that
        # are actually used below instead.
        assert all(layer is not None for layer in (q_norm, k_norm, norm)), \
            'q_norm, k_norm and norm must be provided if qk_norm or scale_norm is True'
    head_dim = dim//num_heads
    scale = head_dim ** -0.5
    B, N, C = x.shape
    # (B, N, 3*C) -> (3, B, heads, N, head_dim), then split into q, k, v.
    qkv_ = qkv(x).reshape(B, N, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
    q, k, v = qkv_.unbind(0)
    q, k = q_norm(q), k_norm(k)
    q = q * scale
    attn = q @ k.transpose(-2, -1)
    attn = attn + attn_mask if attn_mask is not None else attn
    attn = attn.softmax(-1)
    attn = attn_drop(attn)
    # Detached attention weights act as constants in the value mixing.
    x = epsilon_lrp(torch.matmul, 1e-6, attn.detach(), v)
    x = x.transpose(1, 2).reshape(B,N,C)
    x = norm(x)
    x = proj(x)
    x = proj_drop(x)
    return x

In [9]:
def get_module(module, input_layer=False):
    """Select the forward wrapper function to use for ``module``.

    ``input_layer`` only matters for ``Conv2d`` modules: with it set they map
    to ``conv_2d``; otherwise they fall through to ``identity`` like every
    other unrecognised module.
    """
    nn = torch.nn
    if input_layer and isinstance(module, nn.Conv2d):
        return conv_2d
    if isinstance(module, (nn.Dropout, nn.Identity, nn.GELU)):
        return identity
    if isinstance(module, nn.LayerNorm) or isinstance(module, timm.layers.norm.LayerNorm):
        return layernorm
    if isinstance(module, timm.layers.attention.Attention):
        return multihead_attn_fn_cp
    if isinstance(module, nn.Linear):
        return linear
    # Anything not recognised above is simply executed as-is.
    return identity

In [10]:
def patch_embed_block(input, block):
    """Run the patch-embedding stage: projection, flatten, norm, position embedding.

    Args:
        input: image batch passed to ``block.proj``.
        block: patch-embedding module providing ``proj`` and ``norm``.

    Relies on the notebook-level ``model`` for ``_pos_embed``.
    """
    # This is the network's input projection, so request the input-layer
    # wrapper; its forward output is the same as calling the module directly.
    proj = get_module(block.proj, input_layer=True)
    x = proj(block.proj, input)
    # Flatten the spatial grid into a token axis *before* normalising, so a
    # non-Identity norm acts on the channel axis.  With an Identity norm the
    # result is the same as normalising first.
    x = x.flatten(2).transpose(1, 2)
    norm = get_module(block.norm)
    x = norm(block.norm, x)
    x = model._pos_embed(x)
    return x
def attn_block(input, block):
    """Run one transformer block as two residual branches (attention, then MLP).

    Every sub-module is executed through the wrapper chosen by ``get_module``
    and both residual sums go through ``add``.
    """
    def run(layer, tensor):
        # Each wrapper takes the module first and the tensor second.
        return get_module(layer)(layer, tensor)

    mlp = block.mlp

    # Attention branch: norm1 -> attn -> ls1 -> drop_path1, added to the input.
    branch = run(block.norm1, input)
    branch = run(block.attn, branch)
    branch = run(block.ls1, branch)
    branch = run(block.drop_path1, branch)
    x = add(input, branch)

    # MLP branch: norm2 -> fc1 -> act -> drop1 -> norm -> fc2 -> drop2 -> ls2
    # -> drop_path2, added to the attention result.
    branch = run(block.norm2, x)
    branch = run(mlp.fc1, branch)
    branch = run(mlp.act, branch)
    branch = run(mlp.drop1, branch)
    branch = run(mlp.norm, branch)
    branch = run(mlp.fc2, branch)
    branch = run(mlp.drop2, branch)
    branch = run(block.ls2, branch)
    branch = run(block.drop_path2, branch)
    x = add(x, branch)
    return x



In [11]:
# `x` is another name for the preprocessed image batch; it is the input of the
# manual layer-by-layer pass below.
x = data

In [12]:
# Manual forward pass: walk the model's direct children in registration order
# and run each one through the wrapper functions defined above.
# NOTE(review): this overwrites the `y` computed by the earlier model(data) cell.
for name, module in model.named_children():
    if isinstance(module, timm.layers.patch_embed.PatchEmbed):
        # Patch embedding consumes the image batch `x`.
        y = patch_embed_block(x, module)
    elif isinstance(module, torch.nn.Sequential):
        # NOTE(review): the inner loop rebinds `name`; harmless here because the
        # outer loop reassigns it before it is read again.
        for name, n_module in module.named_children():
            if isinstance(n_module, timm.models.vision_transformer.Block):
                y = attn_block(y, n_module)
        
    else:
        # Pooling is a model method rather than a child module, so it is
        # applied explicitly right before "fc_norm".
        if name == "fc_norm":
            y = model.pool(y)
        func = get_module(module)
        y = func(inputs = y, fn = module)

In [14]:
# Seed relevance: zero everywhere except at the arg-max entry of `y`, which
# keeps its own value.
# NOTE(review): torch.argmax(y) without `dim` indexes the flattened tensor; using
# it as a column index of row 0 assumes `y` has a single row -- confirm.
R = torch.zeros_like(y)
R[0,torch.argmax(y)] = y[0,torch.argmax(y)]

In [15]:
# Display the seeded relevance tensor (non-zero only at the arg-max position of `y`).
R

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0