In [1]:
import torch.nn.functional as F
from torch_explain.nn.concepts import ConceptReasoningLayer, IntpLinearLayer, ConceptReasoningLayerMod
import torch
import torch_explain as te
from torch_explain import datasets
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

x, c, y = datasets.mux41(10)
x_train, x_test, c_train, c_test, y_train, y_test = train_test_split(
    x, c, y, test_size=0.33, random_state=42)

embedding_size = 8
concept_encoder = torch.nn.Sequential(
    torch.nn.Linear(x.shape[1], 10),
    torch.nn.LeakyReLU(),
    te.nn.ConceptEmbedding(10, c.shape[1], embedding_size),
)


In [2]:

# -------------------------------------#

y_train = F.one_hot(y_train.long().ravel()).float()
y_test = F.one_hot(y_test.long().ravel()).float()

In [3]:
y_train.shape

torch.Size([6, 2])

In [4]:
concept_encoder(x_train)[0].shape

torch.Size([6, 6, 8])

In [5]:

task_predictor = ConceptReasoningLayerMod(embedding_size, y_train.shape[1], log=True)
model = torch.nn.Sequential(concept_encoder, task_predictor)

optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
# loss_form = torch.nn.BCELoss()
loss_form = torch.nn.BCEWithLogitsLoss()
model.train()
for epoch in range(1):
    optimizer.zero_grad()

    # generate concept and task predictions
    c_emb, c_pred = concept_encoder(x_train)
    y_pred = task_predictor(c_emb, c_pred)
    

    # compute loss
    concept_loss = loss_form(c_pred, c_train)
    task_loss = loss_form(y_pred, y_train)
    loss = concept_loss + 0.5*task_loss

    loss.backward()
    optimizer.step()

local_explanations = task_predictor.explain(c_emb, c_pred, 'local')
global_explanations = task_predictor.explain(c_emb, c_pred, 'global')

# print(local_explanations)
print(global_explanations)


2024-06-14 17:39:49,453 - torch_explain.nn.concepts - INFO - Values: torch.Size([6, 6, 2])
2024-06-14 17:39:49,458 - torch_explain.nn.concepts - INFO - Values: tensor([[[0.5968, 0.5968],
         [0.5289, 0.5289],
         [0.5754, 0.5754],
         [0.5604, 0.5604],
         [0.5827, 0.5827],
         [0.5707, 0.5707]],

        [[0.6055, 0.6055],
         [0.5255, 0.5255],
         [0.5803, 0.5803],
         [0.5367, 0.5367],
         [0.5871, 0.5871],
         [0.5590, 0.5590]],

        [[0.5965, 0.5965],
         [0.5269, 0.5269],
         [0.5693, 0.5693],
         [0.5499, 0.5499],
         [0.5824, 0.5824],
         [0.5638, 0.5638]],

        [[0.6023, 0.6023],
         [0.5265, 0.5265],
         [0.5697, 0.5697],
         [0.5548, 0.5548],
         [0.5829, 0.5829],
         [0.5632, 0.5632]],

        [[0.6021, 0.6021],
         [0.5236, 0.5236],
         [0.5839, 0.5839],
         [0.5433, 0.5433],
         [0.5778, 0.5778],
         [0.5609, 0.5609]],

        [[0.6023, 0.

[{'class': 'y_0', 'explanation': 'c_0 & c_1 & c_2 & c_3 & c_4 & c_5', 'count': 6}]


In [9]:
from torch_explain.nn.concepts import ConceptReasoningLayerMod
dcrm = ConceptReasoningLayerMod(embedding_size, y_train.shape[1], log=False)
dcrm(c_emb, c_pred).shape

torch.Size([6, 2])

In [30]:
from torch_explain.nn.concepts import SignRelevanceAttention, SignRelevanceNet, WeightedMerger
sra = SignRelevanceAttention(c_pred.shape[1], y_train.shape[1])
srn = SignRelevanceNet(c_pred.shape[1], y_train.shape[1])
wm = WeightedMerger(c_pred.shape[1])

tensor([[[0.6354, 0.7199],
         [0.9998, 0.9998],
         [0.5767, 0.6782],
         [0.9948, 0.9961],
         [0.9640, 0.9726],
         [0.7018, 0.7699]],

        [[0.6355, 0.7198],
         [0.9998, 0.9998],
         [0.5758, 0.6773],
         [0.9948, 0.9961],
         [0.9640, 0.9726],
         [0.7017, 0.7698]],

        [[0.6358, 0.7197],
         [0.9998, 0.9998],
         [0.5781, 0.6755],
         [0.9948, 0.9960],
         [0.9640, 0.9727],
         [0.7021, 0.7701]],

        [[0.6351, 0.7202],
         [0.9998, 0.9998],
         [0.5780, 0.6793],
         [0.9949, 0.9961],
         [0.9640, 0.9726],
         [0.7019, 0.7704]],

        [[0.6355, 0.7198],
         [0.9998, 0.9998],
         [0.5767, 0.6778],
         [0.9948, 0.9961],
         [0.9640, 0.9726],
         [0.7017, 0.7700]],

        [[0.6354, 0.7199],
         [0.9998, 0.9998],
         [0.5754, 0.6778],
         [0.9948, 0.9961],
         [0.9640, 0.9725],
         [0.7017, 0.7699]]], grad_fn=<AddBack

In [20]:
preds, sign_attn, filter_attn = task_predictor(c_emb, c_pred, return_attn=True)

In [None]:
c_emb.shape , c_pred.shape

In [7]:
gex0 = [item['explanation'] for item in global_explanations if item['class'] == 'y_0']
gex1 = [item['explanation'] for item in global_explanations if item['class'] == 'y_1']

In [8]:
aggregated_exp0 = "|".join(gex0)
aggregated_exp1 = "|".join(gex1)

In [9]:
from sympy.logic.boolalg import to_dnf
to_dnf(aggregated_exp0, simplify=True)

(c_0 & c_1) | (~c_0 & ~c_1)

In [10]:
to_dnf(aggregated_exp1, simplify=True)

(c_0 & ~c_1) | (c_1 & ~c_0)