In [None]:
from model import CrossAttOmics
import torch 
from einops.layers.torch import Rearrange, EinMix
from functools import partial
from einops import rearrange
from torch.nn import LayerNorm, BatchNorm1d, Identity, Module
from captum.attr._utils.lrp_rules import EpsilonRule, IdentityRule
from captum.attr import LayerLRP, LRP

import matplotlib.pyplot as plt
import seaborn as sn 
import seaborn.objects as so
import pandas as pd

class CrossAttOmicsCaptumModel(Module):
    """Captum-compatible wrapper around a trained CrossAttOmics model.

    Captum attribution methods feed inputs as a tuple of positional tensors,
    while the wrapped model's fusion step consumes a dict keyed by modality
    name. This wrapper takes one positional argument per modality, runs each
    through its modality encoder, fuses the encodings, applies the multimodal
    encoder and returns the classifier's ``cancer_type`` output.
    """

    def __init__(self, model):
        super().__init__()
        # When `model` is an nn.Module it is registered as a submodule, so
        # .eval()/.to() on the wrapper also reach the wrapped model.
        self.model = model

    def forward(self, mRNA_coding, mRNA_non_coding, DNAm_genes, CNV, miRNA, Protein):
        # (modality name, raw input) pairs in positional-argument order; the
        # tuple handed to Captum must follow this same order.
        named_inputs = (
            ("mRNA_coding", mRNA_coding),
            ("mRNA_non_coding", mRNA_non_coding),
            ("DNAm_genes", DNAm_genes),
            ("CNV", CNV),
            ("miRNA", miRNA),
            ("Protein", Protein),
        )
        encoders = self.model.modalities_encoders
        # Encoders run in the order above; the dict keeps that insertion order.
        encodings = {
            modality: encoders[modality](raw).embedding
            for modality, raw in named_inputs
        }
        fused = self.model.fusion(encodings).embedding
        multimodal = self.model.multimodal_encoder(fused).embedding
        return self.model.classifier(multimodal).cancer_type

# Define LRP rules to use with the different layers.
def set_custom_rules(model):
    """Attach LRP rules to layer types that Captum's LRP does not support natively.

    Every submodule of ``model`` is visited and, when its type matches, gets a
    ``rule`` attribute: LayerNorm, BatchNorm1d and einops EinMix receive an
    ``EpsilonRule`` (EinMix is treated like a linear layer), Identity receives
    an ``IdentityRule``. Other modules are left untouched. ``model`` is
    modified in place; nothing is returned.
    """
    # Entries are checked in order; should a module match more than one
    # entry, the rule from the later entry is the one that remains.
    rule_table = (
        ((LayerNorm, BatchNorm1d), EpsilonRule),
        ((Identity,), IdentityRule),
        ((EinMix,), EpsilonRule),  # EinMix behaves like a Linear layer
    )
    for submodule in model.modules():
        for layer_types, make_rule in rule_table:
            if isinstance(submodule, layer_types):
                submodule.rule = make_rule()
            

def get_test_data(ckpt_path=None):
    """Load the test batch and the fitted label encoder (placeholder).

    This must be implemented for the local data setup before the attribution
    cells can run. The notebook uses the result as ``batch, le``:

    - ``batch["x"][modality]`` holds the input tensor of each modality
      (``"mRNA_coding"``, ``"mRNA_non_coding"``, ``"DNAm_genes"``, ``"CNV"``,
      ``"miRNA"``, ``"Protein"``);
    - ``batch["y"]["cancer_type"]`` holds the integer class labels;
    - ``le`` provides ``inverse_transform`` mapping label ids back to cancer
      names.

    Parameters
    ----------
    ckpt_path : str, optional
        Checkpoint path the notebook passes in; available to an implementation
        that needs it (e.g. to locate the matching data split).

    Raises
    ------
    NotImplementedError
        Always, until a concrete loader is written. Raising here gives a clear
        message instead of an opaque unpacking error on ``None``.
    """
    raise NotImplementedError(
        "get_test_data() is a placeholder: implement it to return (batch, le) "
        "for the test set before running the attribution cells."
    )

In [None]:
# Path to the trained CrossAttOmics checkpoint — must be set before running.
ckpt_path = ""
if not ckpt_path:
    raise ValueError("Set `ckpt_path` to a CrossAttOmics checkpoint before running this cell.")
# Fall back to CPU so the notebook still runs on machines without a GPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
model = CrossAttOmics.load_from_checkpoint(ckpt_path, map_location="cpu").to(DEVICE)

# Convert every einops `Rearrange` layer to the functional `rearrange` form.
# Captum's LRP is incompatible with a layer instance that is used several
# times, and these layers perform no computation of their own, so swapping
# them for plain function calls leaves the forward pass unchanged.
# `remove_duplicate=False` yields every attribute path, so a Rearrange
# instance registered under several parents is replaced at each of them
# (the default would only report its first path).
rearrange_layers = {
    path: module
    for path, module in model.named_modules(remove_duplicate=False)
    if isinstance(module, Rearrange)
}
for m_path, module in rearrange_layers.items():
    parent_path, _, attr_name = m_path.rpartition(".")
    parent = model.get_submodule(parent_path)  # "" resolves to `model` itself
    # Remove the registered submodule first; the functional replacement is
    # then stored as a plain attribute.
    delattr(parent, attr_name)
    setattr(parent, attr_name, partial(rearrange, pattern=module.pattern, **module.axes_lengths))

set_custom_rules(model)
# Captum passes inputs as a tuple of tensors, so wrap the model to take one
# positional argument per modality instead of a dict.
CaptumModel = CrossAttOmicsCaptumModel(model)
CaptumModel.eval()
batch, le = get_test_data(ckpt_path)

In [None]:
# Layers whose LRP relevance is reported: the `norm` layer of every
# cross-attention interaction in the fusion module and, optionally, the
# last interaction `norm_layer` of each single-omics encoder.
add_single_omics = True
# NOTE(review): CrossAttOmicsCaptumModel.forward calls `model.fusion`, while
# this cell reads `model.fusion_fn` — confirm both attributes exist.
fusion_fn = CaptumModel.model.fusion_fn
# Built from scratch on every run, so re-executing the cell never appends to
# lists left over from a previous execution.
layers, layer_name = [], []
for interaction, cross_layer in fusion_fn.cross_layers.items():
    layers.append(cross_layer.norm)
    layer_name.append(interaction)
if add_single_omics:
    for omics in fusion_fn.single_omics:
        layers.append(CaptumModel.model.modalities_encoders[omics].attOmics_layers[-1].interaction.norm_layer)
        layer_name.append(omics)

# Positional input order must match CrossAttOmicsCaptumModel.forward.
MODALITIES = ("mRNA_coding", "mRNA_non_coding", "DNAm_genes", "CNV", "miRNA", "Protein")
label = batch["y"]["cancer_type"]
attribution = LayerLRP(CaptumModel, layers)
# With a list of layers, LayerLRP returns one relevance tensor per layer.
lrp_ca = attribution.attribute(tuple(batch["x"][m] for m in MODALITIES), target=label)

# Per cancer: mean over samples of each sample's total relevance in a layer.
unique = torch.unique(label)
res = {"cancer": le.inverse_transform(unique.cpu().numpy())}
for name, relevance in zip(layer_name, lrp_ca):
    # Sums dims 1 and 2 (all non-batch dims when the relevance is 3-D).
    per_sample = relevance.detach().sum(dim=(1, 2))
    res[name] = [per_sample[label == j].mean().item() for j in unique]
res = pd.DataFrame.from_dict(res).melt(id_vars="cancer")


In [None]:
# Shorter labels for plotting. `regex=False` makes every replacement literal
# on all pandas versions (before pandas 2.0 the default was regex=True).
res["cancer"] = res["cancer"].str.replace("TCGA-", "", regex=False)
# Applied in order: "-_-" joins the two modalities of an interaction name.
VARIABLE_RENAMES = {
    "-_-": "→",
    "DNAm_genes": "DNAm",
    "mRNA_coding": "mRNA",
    "mRNA_non_coding": "nc mRNA",
}
variable_labels = res["variable"]
for old, new in VARIABLE_RENAMES.items():
    variable_labels = variable_labels.str.replace(old, new, regex=False)
res["variable"] = variable_labels

# Group cancers into panels of CANCERS_PER_PANEL for the side-by-side plots;
# the last panel holds the remainder when the count is not a multiple.
CANCERS_PER_PANEL = 9
cancer_names = res["cancer"].unique()
cancers = [
    cancer_names[start:start + CANCERS_PER_PANEL]
    for start in range(0, len(cancer_names), CANCERS_PER_PANEL)
]

In [None]:
# One horizontal bar panel per group in `cancers`; only the first panel
# contributes the (shared) legend.
FIGURE_SIZE = (7.00787, 4.33)  # inches
# Size the figure here: seaborn's `Plot.layout(size=...)` sizes figures that
# seaborn creates, not an existing one targeted through `.on(ax)`.
# squeeze=False keeps `axes` indexable even with a single panel.
fig, axes = plt.subplots(1, len(cancers), sharex=True, squeeze=False, figsize=FIGURE_SIZE)
axes = axes[0]
fig.subplots_adjust(wspace=0.3)
fig.supxlabel("LRP relevance scores")
for i, (ax, cancer_group) in enumerate(zip(axes, cancers)):
    data = res[res["cancer"].isin(cancer_group)]
    ax.spines[["top", "left", "right"]].set_visible(False)
    ax.tick_params("y", left=False, pad=0)
    (
        so.Plot(data=data, x="value", y="cancer", color="variable")
        .add(so.Bar(), so.Dodge(), legend=(i == 0))
        .label(x=None, y=None, color="Interaction")
        .on(ax)
        .plot()
    )

# Place the legend below the panels, spanning from the left edge of the first
# axes to the right edge of the last one.
first_box, last_box = axes[0].get_position(), axes[-1].get_position()
sn.move_legend(
    fig, "lower center",
    bbox_to_anchor=(first_box.x0, -0.15, last_box.x1 - first_box.x0, 0.08), ncol=5, frameon=False,
)