In [2]:
import pickle
import tensorboard
from modules.transcripts import Transcripts
from matplotlib import pyplot as plt
from tqdm.notebook import tqdm
from collections import Counter
from modules.model import FOMOnet
from modules.lrp import LRP_FOMOnet
from modules.utils import *
import torch
import torch.nn as nn
import torchvision
from torch.nn import functional as F

In [10]:
# Example usage: build the LRP wrapper and a random example input.
# Seed torch's RNG so `torch.randn` below (and any random initialisation that
# LRP_FOMOnet may perform) is reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

BATCH_SIZE, NUM_CHANNELS, SEQ_LEN = 1, 4, 1000

model = LRP_FOMOnet()
# Example input with shape (batch_size, num_channels, sequence_length).
input_data = torch.randn(BATCH_SIZE, NUM_CHANNELS, SEQ_LEN)

In [15]:
# Run the LRP pass on the example input. `model.lrp` comes from LRP_FOMOnet
# (imported from modules.lrp), whose source is not shown in this notebook.
# NOTE(review): the saved output of this cell ends in a RuntimeError
# ("weight of size [1, 32, 1], expected input[1, 1, 1000] to have 32 channels"),
# i.e. the call fails inside `model.lrp` on this input — investigate before
# using `relevance_scores` downstream.
relevance_scores = model.lrp(input_data)

tensor([[[0.6243, 0.6243, 0.6243, 0.4150, 0.4150, 0.5256, 0.5256, 0.7172,
          0.7172, 0.5438, 0.5438, 0.5438, 0.2740, 0.2740, 0.2918, 0.2918,
          0.8202, 0.8202, 0.3225, 0.3225, 0.3225, 0.8965, 0.8965, 0.6363,
          0.6363, 0.6699, 0.6699, 0.7900, 0.7900, 0.7900, 0.3507, 0.3507,
          0.8721, 0.8721, 0.3027, 0.3027, 0.4064, 0.4064, 0.3514, 0.3514,
          0.3514, 0.5414, 0.5414, 0.2146, 0.2146, 0.4702, 0.4702, 0.3132,
          0.3132, 0.3132, 0.3802, 0.3802, 0.3456, 0.3456, 0.3219, 0.3219,
          0.6241, 0.6241, 0.6241, 0.7145, 0.7145, 0.3808, 0.3808, 0.6249,
          0.6249, 0.6648, 0.6648, 0.3393, 0.3393, 0.3393, 0.3644, 0.3644,
          0.5011, 0.5011, 0.3421, 0.3421, 0.5361, 0.5361, 0.5361, 0.3639,
          0.3639, 0.5889, 0.5889, 0.3987, 0.3987, 0.7356, 0.7356, 0.7356,
          0.3575, 0.3575, 0.2537, 0.2537, 0.8715, 0.8715, 0.2844, 0.2844,
          0.8434, 0.8434, 0.8434, 0.3524, 0.3524, 0.3357, 0.3357, 0.5335,
          0.5335, 0.6796, 0.6796, 0.67

RuntimeError: Given groups=1, weight of size [1, 32, 1], expected input[1, 1, 1000] to have 32 channels, but got 1 channels instead

In [11]:
def lrp_backward(relevance, layer):
    """Propagate ``relevance`` back through a single ``layer``.

    Thin entry point that forwards to ``lrp_linear`` (note the swapped
    argument order: this function takes relevance first, ``lrp_linear``
    takes the layer first) and returns whatever it produces.
    """
    return lrp_linear(layer=layer, relevance=relevance)

In [12]:
def lrp_linear(layer, relevance):
    """Run one relevance-propagation step for ``layer``, dispatching on its type.

    ``nn.Conv1d`` layers are handled by ``lrp_conv1d`` and ``nn.Linear`` layers
    by ``lrp_linear_layer``; each receives the layer's forward output on
    ``relevance`` together with the layer's parameters. Any other layer type
    is not executed and ``relevance`` is returned unchanged.

    NOTE(review): the forward pass is evaluated on ``relevance`` itself, so
    ``relevance`` must have the layer's *input* shape here; standard LRP
    evaluates the layer on its input activations instead — confirm intent.

    Args:
        layer: the module to propagate through.
        relevance: tensor fed to ``layer``'s forward pass.

    Returns:
        The helper's result for Conv1d/Linear layers, else ``relevance``.
    """
    # The forward pass is only computed for layer types that need it; the old
    # unconditional `layer.out_channels` lookup (unused) raised AttributeError
    # for nn.Linear and every non-convolutional layer.
    if isinstance(layer, nn.Conv1d):
        layer_out = layer(relevance)
        return lrp_conv1d(
            layer_out,
            relevance,
            layer.weight,
            layer.bias,
            layer.stride[0],
            layer.padding[0],
        )

    if isinstance(layer, nn.Linear):
        layer_out = layer(relevance)
        return lrp_linear_layer(layer_out, relevance, layer.weight, layer.bias)

    return relevance

In [13]:
def lrp_conv1d(layer_out, relevance, weights, bias, stride, padding, eps=1e-9, output_padding=0):
    """Project relevance back through a 1-D convolution (LRP-epsilon style).

    Divides ``relevance`` element-wise by the sign-stabilised layer output and
    maps the result back onto the input grid with the transposed convolution
    (``F.conv_transpose1d``), which is the adjoint of the forward ``conv1d``.
    This replaces a hand-rolled unfold/flip/fold version: ``F.unfold`` and
    ``F.fold`` are 2-D spatial ops and rejected the 3-D tensor / 1-tuple kernel,
    so the old code always raised.

    Assumptions (NOTE(review): confirm against the callers):
        - ``layer_out`` is the layer's output ``z`` and ``relevance`` is the
          relevance of that output; both are (batch, out_channels, out_length).
        - ``weights`` is the Conv1d weight (out_channels, in_channels, kernel),
          with groups=1 and dilation=1 (neither is passed in).

    Args:
        layer_out: forward output of the convolution.
        relevance: relevance assigned to ``layer_out``.
        weights: convolution weight tensor.
        bias: accepted for signature compatibility; it is NOT added to the
            result, since adding a bias to relevance scores breaks conservation.
        stride: integer stride used in the forward pass.
        padding: integer padding used in the forward pass.
        eps: stabiliser magnitude, applied with the sign of ``layer_out`` so
            the denominator is pushed away from zero rather than onto it.
        output_padding: extra samples on the right so the result matches the
            forward input length when ``stride > 1`` (must be < ``stride``).

    Returns:
        Tensor (batch, in_channels, in_length) of per-input contributions.
        Multiplying it by the layer's input activations gives the input
        relevance (sum-conserving up to ``eps`` when the layer has no bias).
    """
    stabilizer = torch.full_like(layer_out, eps)
    # Sign-aware epsilon: a plain "+ eps" maps z == -eps to exactly 0 (-> inf).
    denominator = layer_out + torch.where(layer_out >= 0, stabilizer, -stabilizer)
    scaled = relevance / denominator
    return F.conv_transpose1d(
        scaled,
        weights,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
    )

In [14]:
def lrp_linear_layer(layer_out, relevance, weights, bias, eps=1e-9):
    """Project relevance back through a fully connected layer (LRP-epsilon style).

    Computes ``(relevance / z_stab) @ weights`` where ``z_stab`` is
    ``layer_out`` pushed away from zero by ``eps`` in the direction of its sign.

    Assumes (NOTE(review): confirm against callers) ``relevance`` and
    ``layer_out`` are both (batch, out_features) and ``weights`` is the
    nn.Linear weight (out_features, in_features).

    Args:
        layer_out: forward output ``z`` of the layer.
        relevance: relevance assigned to ``layer_out``.
        weights: layer weight matrix.
        bias: accepted for signature compatibility; NOT added to the result.
            The previous code added the (out_features,) bias to a
            (batch, in_features, 1) tensor, which raised a shape error for any
            non-square layer and is not part of the epsilon rule.
        eps: stabiliser magnitude.

    Returns:
        (batch, in_features, 1) tensor of per-input contributions; the trailing
        singleton axis is kept for compatibility with the previous return shape.
    """
    stabilizer = torch.full_like(layer_out, eps)
    # Sign-aware epsilon: a plain "+ 1e-9" turns z == -1e-9 into exactly 0,
    # so the division would produce inf despite the stabiliser.
    denominator = layer_out + torch.where(layer_out >= 0, stabilizer, -stabilizer)
    contributions = (relevance / denominator).matmul(weights)
    return contributions.unsqueeze(-1)