In [1]:
import torch.nn.functional as F
from torch_explain.nn.concepts import ConceptReasoningLayer, IntpLinearLayer, ConceptReasoningLayerMod, ReasoningLinearLayer
import torch
import torch_explain as te
from torch_explain import datasets
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

x, c, y = datasets.trigonometry(100)
x_train, x_test, c_train, c_test, y_train, y_test = train_test_split(
    x, c, y, test_size=0.33, random_state=42)

embedding_size = 8
concept_encoder = torch.nn.Sequential(
    torch.nn.Linear(x.shape[1], 10),
    torch.nn.LeakyReLU(),
    te.nn.ConceptEmbedding(10, c.shape[1], embedding_size),
)


In [2]:

# -------------------------------------#

y_train = F.one_hot(y_train.long().ravel()).float()
y_test = F.one_hot(y_test.long().ravel()).float()

In [3]:
y_train.shape

torch.Size([67, 2])

In [4]:
concept_encoder(x_train)[0].shape

torch.Size([67, 3, 8])

In [6]:

task_predictor = ConceptReasoningLayer(embedding_size, y_train.shape[1])
model = torch.nn.Sequential(concept_encoder, task_predictor)

model.train()
# optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
# loss = torch.nn.BCELoss()
# c_emb, c_pred = concept_encoder(x_train)
# print(c_emb.shape, c_pred.shape)
# y_pred , sign_attn , filter_attn = task_predictor(c_emb, c_pred, return_attn=True)
# print(y_pred.shape, sign_attn.shape, filter_attn.shape)

# global_explainer = ReasoningLinearLayer(sign_attn.shape[1], filter_attn.shape[1], y_train.shape[1], log=True, modality='Attention')
# global_explainer.train()
# global_explainer(sign_attn, filter_attn, y_train)

optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
# loss_form = torch.nn.BCELoss()
loss_form = torch.nn.BCEWithLogitsLoss()
model.train()
for epoch in range(1):
    optimizer.zero_grad()

    # generate concept and task predictions
    c_emb, c_pred = concept_encoder(x_train)
    y_pred = task_predictor(c_emb, c_pred)
    

    # compute loss
    concept_loss = loss_form(c_pred, c_train)
    task_loss = loss_form(y_pred, y_train)
    loss = concept_loss + 0.5*task_loss

    loss.backward()
    optimizer.step()

local_explanations = task_predictor.explain(c_emb, c_pred, 'local')
global_explanations = task_predictor.explain(c_emb, c_pred, 'global')

# # print(local_explanations)
# print(global_explanations)


In [9]:
from torch_explain.nn.concepts import ConceptReasoningLayerMod
dcrm = ConceptReasoningLayerMod(embedding_size, y_train.shape[1], log=False)
dcrm(c_emb, c_pred).shape

torch.Size([6, 2])

In [30]:
from torch_explain.nn.concepts import SignRelevanceAttention, SignRelevanceNet, WeightedMerger
sra = SignRelevanceAttention(c_pred.shape[1], y_train.shape[1])
srn = SignRelevanceNet(c_pred.shape[1], y_train.shape[1])
wm = WeightedMerger(c_pred.shape[1])

tensor([[[0.6354, 0.7199],
         [0.9998, 0.9998],
         [0.5767, 0.6782],
         [0.9948, 0.9961],
         [0.9640, 0.9726],
         [0.7018, 0.7699]],

        [[0.6355, 0.7198],
         [0.9998, 0.9998],
         [0.5758, 0.6773],
         [0.9948, 0.9961],
         [0.9640, 0.9726],
         [0.7017, 0.7698]],

        [[0.6358, 0.7197],
         [0.9998, 0.9998],
         [0.5781, 0.6755],
         [0.9948, 0.9960],
         [0.9640, 0.9727],
         [0.7021, 0.7701]],

        [[0.6351, 0.7202],
         [0.9998, 0.9998],
         [0.5780, 0.6793],
         [0.9949, 0.9961],
         [0.9640, 0.9726],
         [0.7019, 0.7704]],

        [[0.6355, 0.7198],
         [0.9998, 0.9998],
         [0.5767, 0.6778],
         [0.9948, 0.9961],
         [0.9640, 0.9726],
         [0.7017, 0.7700]],

        [[0.6354, 0.7199],
         [0.9998, 0.9998],
         [0.5754, 0.6778],
         [0.9948, 0.9961],
         [0.9640, 0.9725],
         [0.7017, 0.7699]]], grad_fn=<AddBack

In [14]:
global_explanations

[{'class': 'y_1', 'explanation': 'c_0 & c_1 & c_2', 'count': 31}]

In [7]:
preds, sign_attn, filter_attn = task_predictor(c_emb, c_pred, return_attn=True)

In [8]:
c_emb.shape , c_pred.shape

(torch.Size([67, 3, 8]), torch.Size([67, 3]))

In [9]:
gex0 = [item['explanation'] for item in global_explanations if item['class'] == 'y_0']
gex1 = [item['explanation'] for item in global_explanations if item['class'] == 'y_1']

In [10]:
aggregated_exp0 = "|".join(gex0)
aggregated_exp1 = "|".join(gex1)

In [11]:
from sympy.logic.boolalg import to_dnf
to_dnf(aggregated_exp0, simplify=True)

SympifyError: Sympify of expression 'could not parse ''' failed, because of exception being raised:
SyntaxError: invalid syntax (<string>, line 0)

In [10]:
to_dnf(aggregated_exp1, simplify=True)

(c_0 & ~c_1) | (c_1 & ~c_0)